In [1]:
import pickle
import gzip
import os
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, Subset, DataLoader
#import torchvideo.transforms as VT
import torchvision.transforms as IT
import cv2
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import math
from transformers import BertConfig, BertForSequenceClassification, EncoderDecoderModel, BertModel, AutoConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelWithLMHead, BertTokenizer
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import get_scheduler
from torchviz import make_dot
from datasets import list_datasets, load_metric
import matplotlib.pyplot as plt
import numpy as np

In [2]:
train_dir = r"C:\Sign-Language-Recognition\Datasets\PHOENIX14T_small\archive\videos_phoenix\videos\train"
files = os.listdir(train_dir)

# Input / model sizes shared by the cells below.
im_size = 224
vid_batch_size = 8
frame_batch_size = 8
slrt_input_len = 200      # sign-feature sequences are truncated / zero-padded to this many steps
max_vid_len = 200
frame_encoding_size = 1024
recog_loss_weight = 0.01
max_gradient_norm = 5

# Frame batches must tile a full video batch exactly.
assert (vid_batch_size * max_vid_len) % frame_batch_size == 0

# `is_available` must be *called*: the bare function object is always truthy,
# which would select CUDA even on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
# lengths = []
# for vid in tqdm(files):
#     cap = cv2.VideoCapture(os.path.join(train_dir, vid))
#     length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
#     lengths.append(length)

In [None]:
# plt.hist(lengths)

In [None]:
# Preview a single frame of the first training video.
sample_path = os.path.join(train_dir, files[0])
frames, _audio, _info = torchvision.io.read_video(sample_path)  # frames: time x height x width x channels
print(frames.shape)
vid = F.interpolate(frames.permute(0, 3, 1, 2), size=299)  # vid: time x channels x height x width
IT.ToPILImage()(vid[10])

In [3]:
class PhoenixDataset(Dataset):
    """PHOENIX-2014-T annotation split exposed as (sign, gloss, text) samples.

    The annotation file is a gzip-compressed pickle; each entry is indexed with the
    keys 'sign', 'gloss' and 'text'. The 'sign' tensor is truncated / zero-padded
    along its first dimension to a fixed length so samples can be batched.

    Args:
        annotations_path: path to the gzip-compressed pickle for one split.
        max_len: fixed length for the first dimension of 'sign'. When None, the
            notebook-level `slrt_input_len` is used (read at access time).
    """

    def __init__(self, annotations_path, max_len=None):
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted
        # annotation files.
        with gzip.open(annotations_path, 'rb') as f:
            self.annotations = pickle.load(f)
        self.max_len = max_len

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        """Return (sign, gloss, text) with `sign` truncated / zero-padded to the fixed length."""
        entry = self.annotations[index]
        sign, gloss, text = entry['sign'], entry['gloss'], entry['text']

        max_len = getattr(self, 'max_len', None)
        if max_len is None:
            max_len = slrt_input_len

        sign = sign[:max_len]
        # new_zeros keeps the padding on the same dtype/device as the features
        # (torch.zeros would always produce CPU float32).
        padding = sign.new_zeros((max_len - sign.shape[0], *sign.shape[1:]))
        sign = torch.cat([sign, padding], dim=0)

        return (sign, gloss, text)

In [4]:
# Common prefix of the PHOENIX14T annotation files; the split name is appended.
phoenix_annotations_prefix = r"C:\Sign-Language-Recognition\Datasets\PHOENIX14T\phoenix14t.pami0"

phoenix_train = PhoenixDataset(
    annotations_path = phoenix_annotations_prefix + ".train",
)

# Fixed seed so the 500-sample BLEU-monitoring subset is identical across runs.
subset_rng = np.random.default_rng(0)
indices = subset_rng.choice(len(phoenix_train), 500, replace=False)
phoenix_train_small = Subset(phoenix_train, indices)

phoenix_val = PhoenixDataset(
    annotations_path = phoenix_annotations_prefix + ".dev",
)

phoenix_test = PhoenixDataset(
    annotations_path = phoenix_annotations_prefix + ".test",
)

# Only the training loader needs shuffling; the others are used for evaluation,
# where a fixed order keeps runs comparable.
phoenix_train_loader = DataLoader(dataset=phoenix_train, batch_size=vid_batch_size, shuffle=True, num_workers=0)
phoenix_train_small_loader = DataLoader(dataset=phoenix_train_small, batch_size=vid_batch_size, shuffle=False, num_workers=0)
phonix_train_small_loader = phoenix_train_small_loader  # backward-compatible alias for the original misspelled name
phoenix_val_loader = DataLoader(dataset=phoenix_val, batch_size=vid_batch_size, shuffle=False, num_workers=0)
phoenix_test_loader = DataLoader(dataset=phoenix_test, batch_size=vid_batch_size, shuffle=False, num_workers=0)

In [5]:
# Sample counts: full train split, BLEU-monitoring train subset, dev split, test split.
len(phoenix_train), len(phoenix_train_small), len(phoenix_val), len(phoenix_test)

(7096, 500, 519, 642)

In [None]:
# Shape of one padded sign tensor; dim 0 is slrt_input_len after PhoenixDataset padding.
# NOTE(review): later cells reshape with view(-1, 1024), so the feature dim is presumably 1024 — confirm.
phoenix_train[3][0].shape

In [None]:
# %%capture
# frame_encoder = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b0', pretrained=True)
# for i, (n, p) in enumerate(frame_encoder.named_parameters()):
#     if i > 150:
#         break
#     p.requires_grad_(False)
# frame_encoder.classifier = nn.Sequential(nn.Flatten(), nn.Linear(62720, frame_encoding_size))
# cnn_optimizer = torch.optim.AdamW(frame_encoder.parameters(), lr=1e-3)
# frame_encoder.to(device)

In [None]:
# frame_encoder

In [None]:
# len([n for n, p in frame_encoder.named_parameters()])

In [None]:
# One (sign, gloss, text) batch from the shuffled train loader, reused by the
# exploratory cells below (its contents differ between runs).
test_batch = next(iter(phoenix_train_loader))

In [None]:
test_batch

In [None]:
# batch_v_len = torch.max(batch[0])
# vid_encoder_input = torch.empty(vid_batch_size, batch_v_len, frame_encoding_size, device=device)
# for vb in tqdm(range(vid_batch_size)):
#     v_len = batch[0][vb].item()
#     single_vid_encoding = torch.empty(v_len, frame_encoding_size, device=device)
#     for i in range(math.ceil(v_len/frame_batch_size)):
#         #print("1: {:,}".format(torch.cuda.memory_allocated(0)))
#         start_frame = i * frame_batch_size
#         #print("2: {:,}".format(torch.cuda.memory_allocated(0)))
#         end_frame = min((i + 1) * frame_batch_size, v_len)
#         print("3: {:,}".format(torch.cuda.memory_allocated(0)))
#         fe_input = batch[1][vb][start_frame:end_frame].to(device)
#         #print("4: {:,}".format(torch.cuda.memory_allocated(0)))
#         encoded_frames = frame_encoder(fe_input)
#         #print("5: {:,}".format(torch.cuda.memory_allocated(0)))
#         single_vid_encoding[start_frame:end_frame, :] = encoded_frames.unsqueeze(0)
#         #print("6: {:,}".format(torch.cuda.memory_allocated(0)))
#         del fe_input
#         torch.cuda.empty_cache()
#         #print("7: {:,}".format(torch.cuda.memory_allocated(0)))
#         del encoded_frames
#         torch.cuda.empty_cache()
#         #print("8: {:,}".format(torch.cuda.memory_allocated(0)))
#     single_vid_encoding = torch.cat((single_vid_encoding, torch.zeros(batch_v_len - v_len, frame_encoding_size, device=device)), dim = 0)
#     vid_encoder_input[vb, :, :] = single_vid_encoding
#     del single_vid_encoding
# vid_encoder_input = torch.cat((vid_encoder_input, torch.zeros(vid_batch_size, slrt_input_len - vid_encoder_input.shape[1], frame_encoding_size, device=device)), dim = 1)

In [None]:
# vid_encoder_input.shape

In [None]:
# make_dot(vid_encoder_input)

In [None]:
german_tokenizer = AutoTokenizer.from_pretrained("dbmdz/german-gpt2")
german_gpt = AutoModelWithLMHead.from_pretrained("dbmdz/german-gpt2")

In [None]:
vocab_size = len(german_tokenizer)
vocab_size

In [None]:
encoded = german_tokenizer("MORGEN TEMPERATUR ELF SAUER LAND BIS MAXIMAL EINS ZWANZIG BERG OST", return_tensors="pt").input_ids

In [None]:
german_tokenizer.decode(encoded[0])

In [None]:
# batch[0].shape

In [None]:
# bert_config = BertConfig(
#     hidden_size = frame_encoding_size,
#     num_hidden_layers = 1,
#     num_attention_heads = 8,
#     intermetiate_size = frame_encoding_size * 2,
#     max_length = slrt_input_len
# )
# bert = BertModel(bert_config, add_pooling_layer=False)

In [None]:
# bert

In [None]:
# make_dot(encoder(test_batch[0]))

In [None]:
# Inspect the German GPT-2 architecture.
german_gpt

In [None]:
# The stack of transformer blocks of the GPT-2 model.
german_gpt.transformer.h

In [None]:
class TransformerEncoder(nn.Module):
    """Single Transformer encoder layer over frame embeddings, followed by LayerNorm.

    Width is the notebook-level `frame_encoding_size`; inputs are batch-first.
    """

    def __init__(self):
        super().__init__()
        width = frame_encoding_size
        # NOTE(review): PositionalEncoding is not defined anywhere in this notebook;
        # confirm it is available before this cell runs.
        self.positional_encoding = PositionalEncoding(width)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=width,
            nhead=8,
            dim_feedforward=2 * width,
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(width)

    def forward(self, embeddings):
        """Add positional information, encode, and normalise the result."""
        encoded = self.encoder_layer(self.positional_encoding(embeddings))
        return self.layer_norm(encoded)

In [None]:
class TransformerDecoder(nn.Module):
    """Single Transformer decoder layer producing vocabulary logits.

    Token ids are embedded at width `frame_encoding_size`, attend to the encoder
    output, and are projected to `vocab_size` logits (both read from the notebook
    globals when the module is constructed).
    """

    def __init__(self):
        super().__init__()
        width = frame_encoding_size
        self.word_embedding = nn.Embedding(vocab_size, width)
        # NOTE(review): PositionalEncoding is not defined anywhere in this notebook;
        # confirm it is available before this cell runs.
        self.positional_encoding = PositionalEncoding(width)
        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=width,
            nhead=8,
            dim_feedforward=2 * width,
            batch_first=True,
        )
        self.linear_final = nn.Linear(width, vocab_size)

    def forward(self, encoder_output, tgt, tgt_mask):
        """Decode target ids `tgt` against `encoder_output` under `tgt_mask`; return logits."""
        embedded = self.positional_encoding(self.word_embedding(tgt))
        decoded = self.decoder_layer(embedded, encoder_output, tgt_mask=tgt_mask)
        return self.linear_final(decoded)

In [None]:
t5 = T5ForConditionalGeneration.from_pretrained("t5-small")
t5.to(device)

In [None]:
test_batch[2]

In [None]:
test_batch[0].shape

In [None]:
vocab_size = len(t5_tokenizer)
vocab_size

In [None]:
frame_to_enc = nn.Linear(1024, 512)
enc_to_gloss_probs = nn.Sequential(
    nn.Linear(512, vocab_size + 1),
    nn.LogSoftmax(dim = -1)
).to(device)
ctc_loss = nn.CTCLoss(blank=vocab_size, zero_infinity=True)

In [None]:
t5_opt = torch.optim.AdamW([
    {'params': t5.parameters(), 'lr': 1e-4},
    {'params': enc_to_gloss_probs.parameters(), 'lr': 5e-3},
    {'params': frame_to_enc.parameters(), 'lr': 5e-3}
])

In [None]:
frame_embeds = frame_to_enc(test_batch[0].view(-1, 1024)).view(vid_batch_size, 200, 512).to(device)
gloss_ids = t5_tokenizer(list(test_batch[1]), return_tensors='pt', padding='max_length', max_length=200).input_ids.to(device)
translations_ids = t5_tokenizer(list(test_batch[2]), return_tensors='pt', padding='max_length', max_length=200).input_ids.to(device)
outputs = t5(inputs_embeds=frame_embeds, labels=translations_ids)
print(outputs.loss)
# outputs.loss.backward()
# t5_opt.step()

In [None]:
outputs.encoder_last_hidden_state.shape

In [None]:
t5_tokenizer.decode(0), t5_tokenizer.decode(vocab_size)

In [None]:
gloss_probs = enc_to_gloss_probs(outputs.encoder_last_hidden_state.view(-1, 512)).view(vid_batch_size, 200, vocab_size).permute(1, 0, 2)
input_lengths = torch.full(size=(vid_batch_size,), fill_value=200, dtype=torch.long).to(device)
glosses = t5_tokenizer(list(batch[1]), return_tensors='pt', padding='max_length', max_length=200)
gloss_ids = glosses.input_ids.to(device)
gloss_lengths = torch.sum(glosses.attention_mask, dim = -1).to(device)
recog_loss = ctc_loss(gloss_probs, gloss_ids, input_lengths, gloss_lengths)
print(recog_loss)

In [None]:
# Inspect the pretrained t5-small configuration.
configuration = AutoConfig.from_pretrained('t5-small')
configuration

In [None]:
# Exploration only: load t5-small with a higher dropout rate. This `model` is
# rebound (and deleted) by the configuration sweep further down.
configuration.dropout_rate = 0.3
model = T5ForConditionalGeneration.from_pretrained("t5-small", config=configuration)

In [None]:
model

In [15]:
# Tokenizer shared by gloss and translation targets in the training code below.
t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
vocab_size = len(t5_tokenizer)
# The CTC blank gets index vocab_size, one past the last tokenizer id; the gloss
# head therefore outputs vocab_size + 1 classes.
# NOTE(review): train() reads this global `ctc_loss`, and later cells rebind it
# (blank=0), so re-running those cells before training changes the blank index.
ctc_loss = nn.CTCLoss(blank=vocab_size, zero_infinity=True)
vocab_size

32100

In [16]:
def initialize_model_and_opt(config, model_name='t5-small', frame_dim=1024):
    """Build a fresh T5 model, the sign-feature adapter layers, and their optimiser.

    Args:
        config: run configuration dict. Required keys: 'model-lr' (learning rate
            for the T5 weights) and 'extra-layer-lr' (learning rate for the adapter
            layers). Optional keys 'dropout_rate' and 'layer_norm_epsilon' override
            the pretrained config before the weights are loaded.
        model_name: Hugging Face checkpoint to start from.
        frame_dim: size of each per-frame sign feature vector fed to 'frame_to_enc'.

    Returns:
        (model, extra_layers, opt):
        - the T5 model on `device`;
        - a dict with 'frame_to_enc' (frame_dim -> d_model) and
          'enc_to_gloss_probs' (d_model -> vocab_size + 1 log-probabilities);
        - an AdamW optimiser over all of them.
    """
    hf_config = AutoConfig.from_pretrained(model_name)

    # Optional overrides of the pretrained regularisation settings.
    for key in ('dropout_rate', 'layer_norm_epsilon'):
        if key in config:
            setattr(hf_config, key, config[key])

    model = T5ForConditionalGeneration.from_pretrained(model_name, config=hf_config).to(device)

    # Size the adapters from the checkpoint's hidden width instead of hard-coding 512.
    d_model = hf_config.d_model
    frame_to_enc = nn.Linear(frame_dim, d_model).to(device)
    # The extra output class is the CTC blank label.
    enc_to_gloss_probs = nn.Sequential(
        nn.Linear(d_model, vocab_size + 1),
        nn.LogSoftmax(dim=-1),
    ).to(device)

    extra_layers = {
        'frame_to_enc': frame_to_enc,
        'enc_to_gloss_probs': enc_to_gloss_probs,
    }
    opt = torch.optim.AdamW([
        {'params': model.parameters(), 'lr': config['model-lr']},
        {'params': enc_to_gloss_probs.parameters(), 'lr': config['extra-layer-lr']},
        {'params': frame_to_enc.parameters(), 'lr': config['extra-layer-lr']},
    ])

    return model, extra_layers, opt

In [None]:
# t5_tokenizer.decode(t5.generate(input_ids=input_ids)[0])

In [17]:
# Corpus BLEU metric consumed by evaluate(); compute() returns a dict with a 'bleu' key.
# NOTE(review): `datasets.load_metric` is deprecated (removed in datasets 3.x). The
# replacement is the `evaluate` library, but that module name would clash with this
# notebook's `evaluate` function — import it under an alias if migrating.
bleu = load_metric("bleu")

In [18]:
def evaluate(model, frame_to_enc, data_loader, metric, tokenizer=None, **generate_kwargs):
    """Generate translations for every batch in `data_loader` and score them with `metric`.

    Args:
        model: seq2seq model with a `generate(inputs_embeds=...)` method (T5 here).
        frame_to_enc: projection from per-frame sign features to the model's
            embedding width.
        data_loader: yields (sign, gloss, text) batches; sign is
            (batch, time, feature_dim).
        metric: metric object with `add_batch(predictions=, references=)` and
            `compute()` (e.g. the BLEU metric above).
        tokenizer: tokenizer used to decode generations; defaults to the
            notebook-level `t5_tokenizer`.
        **generate_kwargs: forwarded to `model.generate`, e.g. `max_new_tokens`.
            Without it, the model's default generation length limit applies, which
            may truncate long sentences.

    Returns:
        The result of `metric.compute()`.

    Leaves `model` in eval mode.
    """
    if tokenizer is None:
        tokenizer = t5_tokenizer

    model.eval()
    progress = tqdm(total=len(data_loader))
    # No gradients are needed; without this the frame projection would build an
    # autograd graph for every evaluation batch.
    with torch.no_grad():
        for signs, _glosses, texts in data_loader:
            signs = signs.to(device)
            batch_len, seq_len, feat_dim = signs.shape
            frame_embeds = frame_to_enc(signs.view(-1, feat_dim)).view(batch_len, seq_len, -1)
            generated = model.generate(inputs_embeds=frame_embeds, **generate_kwargs)

            decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
            predictions = [sentence.split(" ") for sentence in decoded]
            references = [[sentence.split(" ")] for sentence in texts]
            metric.add_batch(predictions=predictions, references=references)

            progress.update(1)
    progress.close()

    return metric.compute()

In [19]:
def train(model, extra_layers, optimizer, tokenizer, recog_loss_weight, dataloaders, n_epochs, save_folder, eval_every=300, should_print=True, max_grad_norm=None):
    """Jointly train T5 translation and a CTC gloss-recognition head on sign features.

    Each step:
    1. Projects the padded sign features into T5's embedding space.
    2. Computes the translation loss on the target text.
    3. Computes a CTC loss on gloss log-probabilities predicted from the encoder's
       last hidden state.
    4. Optimises ``(1 - recog_loss_weight) * translation + recog_loss_weight * recognition``.

    Every `eval_every` steps of each epoch (including step 0), the current losses are
    logged and BLEU is computed on the 'train-small', 'val' and 'test' loaders. A
    checkpoint (T5 weights plus the extra layers) is then written to
    ``save_folder/<epoch>-<step>``.

    Args:
        model: T5ForConditionalGeneration being fine-tuned.
        extra_layers: dict with 'frame_to_enc' and 'enc_to_gloss_probs' modules.
        optimizer: optimiser over the model and the extra layers.
        tokenizer: tokenizer used for both gloss and translation targets.
        recog_loss_weight: weight of the CTC recognition loss (expected in [0, 1]).
        dataloaders: dict with 'train', 'train-small', 'val' and 'test' loaders.
        n_epochs: number of passes over the training loader.
        save_folder: directory for checkpoints.
        eval_every: evaluation / checkpoint interval in steps.
        should_print: print metrics and a sample prediction at each evaluation.
        max_grad_norm: if not None, clip the total gradient norm of all optimised
            parameters to this value before each optimiser step.

    Returns:
        dict with keys 'trans_loss', 'recog_loss', 'total_loss', 'train_bleu',
        'val_bleu', 'test_bleu'; each maps "<epoch>-<step>" to the logged value.
    """
    trans_loss_history = {}
    recog_loss_history = {}
    total_loss_history = {}
    train_bleu_history = {}
    val_bleu_history = {}
    test_bleu_history = {}

    train_loader = dataloaders['train']
    train_small_loader = dataloaders['train-small']
    val_loader = dataloaders['val']
    test_loader = dataloaders['test']
    frame_to_enc = extra_layers['frame_to_enc']
    enc_to_gloss_probs = extra_layers['enc_to_gloss_probs']

    # Local CTC criterion. The gloss head appends one class after the tokenizer
    # vocabulary for the blank label. Building it here avoids depending on the
    # global `ctc_loss`, which other cells rebind with a different blank index.
    recog_criterion = nn.CTCLoss(blank=len(tokenizer), zero_infinity=True)
    optimised_params = [p for group in optimizer.param_groups for p in group['params']]

    for epoch in tqdm(range(n_epochs), desc = 'Epochs'):
        train_progress = tqdm(range(len(train_loader)), desc = 'Batches processed')
        for i, (sign, gloss, text) in enumerate(train_loader):
            optimizer.zero_grad()
            sign = sign.to(device)
            batch_len, seq_len, feat_dim = sign.shape
            frame_embeds = frame_to_enc(sign.view(-1, feat_dim)).view(batch_len, seq_len, -1)

            translations_ids = tokenizer(list(text), return_tensors='pt', padding='max_length', truncation=True, max_length=200).input_ids.to(device)
            # Pad positions must be -100 so the translation loss ignores them instead
            # of training the decoder to emit padding.
            translations_ids[translations_ids == tokenizer.pad_token_id] = -100
            outputs = model(inputs_embeds=frame_embeds, labels=translations_ids)

            # (batch, time, classes) -> (time, batch, classes), as nn.CTCLoss expects.
            gloss_probs = enc_to_gloss_probs(outputs.encoder_last_hidden_state).permute(1, 0, 2)
            input_lengths = torch.full(size=(batch_len,), fill_value=seq_len, dtype=torch.long, device=device)
            glosses = tokenizer(list(gloss), return_tensors='pt', padding='max_length', truncation=True, max_length=200)
            gloss_ids = glosses.input_ids.to(device)
            gloss_lengths = torch.sum(glosses.attention_mask, dim = -1).to(device)
            recog_loss = recog_criterion(gloss_probs, gloss_ids, input_lengths, gloss_lengths)

            loss = (1 - recog_loss_weight) * outputs.loss + recog_loss_weight * recog_loss
            loss.backward()
            if max_grad_norm is not None:
                nn.utils.clip_grad_norm_(optimised_params, max_grad_norm)
            optimizer.step()

            if i % eval_every == 0:
                key = f"{epoch}-{i}"
                model.eval()
                trans_loss_history[key] = outputs.loss.item()
                recog_loss_history[key] = recog_loss.item()
                total_loss_history[key] = loss.item()

                train_bleu_history[key] = evaluate(model, frame_to_enc, train_small_loader, bleu)
                val_bleu_history[key] = evaluate(model, frame_to_enc, val_loader, bleu)
                test_bleu_history[key] = evaluate(model, frame_to_enc, test_loader, bleu)

                checkpoint_dir = os.path.join(save_folder, key)
                model.save_pretrained(checkpoint_dir)
                # save_pretrained only stores the T5 weights. Keep the projection and
                # gloss heads next to them so the checkpoint is usable for inference.
                torch.save(
                    {name: layer.state_dict() for name, layer in extra_layers.items()},
                    os.path.join(checkpoint_dir, 'extra_layers.pt'),
                )

                if should_print:
                    print("Epoch:", epoch, "Iteration:", i)
                    print("Train bleu: {:.2%}".format(train_bleu_history[key]['bleu']))
                    print("Val bleu: {:.2%}".format(val_bleu_history[key]['bleu']))
                    print("Test bleu: {:.2%}".format(test_bleu_history[key]['bleu']))
                    print("Translation loss: {:.3f}".format(outputs.loss.item()))
                    print("Recognition loss: {:.3f}".format(recog_loss.item()))
                    print("\n")
                    print("Gloss:", gloss[0])
                    print("Translation:", text[0])
                    print("Predicted translation:", tokenizer.decode(model.generate(inputs_embeds=frame_embeds)[0]))
                    print("========================================================================================")
                model.train()

            train_progress.update(1)
        train_progress.close()

    histories = {
        'trans_loss': trans_loss_history,
        'recog_loss': recog_loss_history,
        'total_loss': total_loss_history,
        'train_bleu': train_bleu_history,
        'val_bleu': val_bleu_history,
        'test_bleu': test_bleu_history,
    }
    return histories

In [20]:
# Hyper-parameter sweep: one run per recognition (CTC) loss weight; the dropout
# rate and learning rates are shared by every run. Run ids start at 1.
configs = {
    run_id: {
        'recog_loss_weight': weight,
        'dropout_rate': 0.1,
        'model-lr': 1e-4,
        'extra-layer-lr': 5e-3,
    }
    for run_id, weight in enumerate((0.01, 0.1, 1.0), start=1)
}

In [None]:
# The same loaders are shared by every run in the sweep.
phoenix_loaders = {
    'train': phoenix_train_loader,
    'train-small': phonix_train_small_loader,
    'val': phoenix_val_loader,
    'test': phoenix_test_loader,
}

all_histories = []
for i, config in configs.items():
    print("Initializing model...")
    model, extra_layers, opt = initialize_model_and_opt(config)
    print("Done! Training...")
    histories = train(
        model=model,
        extra_layers=extra_layers,
        optimizer=opt,
        tokenizer=t5_tokenizer,
        recog_loss_weight=config['recog_loss_weight'],
        dataloaders=phoenix_loaders,
        n_epochs=10,
        save_folder=f"Models/Phoenix14T/{i}",
    )

    # Release this run's model/optimiser (and cached GPU memory) before the next run.
    del model, extra_layers, opt
    torch.cuda.empty_cache()

    all_histories.append(histories)

In [None]:
# evaluate() takes the frame projection as its second argument (the original call
# omitted it and raised a TypeError). Module.to() moves the layer in place and
# returns it, so it sits on the same device evaluate() moves the inputs to.
evaluate(t5, frame_to_enc.to(device), phoenix_train_loader, bleu)

In [None]:
# --- Legacy experiment: custom Transformer encoder/decoder with the German GPT-2 tokenizer ---
t_encoder = TransformerEncoder()

In [None]:
t_encoder(test_batch[0]).shape

In [None]:
# NOTE(review): TransformerDecoder sizes its embedding/output from the global
# `vocab_size`, which at this point holds len(t5_tokenizer), while the targets
# used with this decoder come from german_tokenizer — confirm the vocab sizes match.
t_decoder = TransformerDecoder()

In [None]:
t_decoder.decoder_layer

In [None]:
# GPT-2 tokenizers have no pad token; reuse EOS so padding='max_length' works.
german_tokenizer.pad_token = german_tokenizer.eos_token

In [None]:
t_encoder.to(device)
t_decoder.to(device)
# NOTE(review): out_to_gloss_probs (defined below) is not included in this
# optimizer, so its weights are never updated.
optimizer = torch.optim.AdamW([{'params': t_encoder.parameters()}, {'params': t_decoder.parameters()}], lr=1e-3)
out_to_gloss_probs = nn.Sequential(
    nn.Linear(frame_encoding_size, vocab_size+1),
    nn.LogSoftmax(dim = -1)
).to(device)
# NOTE(review): default blank=0 here, although the head reserves vocab_size + 1
# classes. This also rebinds the global `ctc_loss` that train() reads.
ctc_loss = nn.CTCLoss(zero_infinity=True)

In [None]:
# batch[0].shape

In [None]:
# NOTE(review): `batch` and `bert` are not defined at top level in this notebook
# (bert's construction is commented out; `batch` only exists later as a loop
# variable). These cells depend on leftover kernel state and fail on a fresh kernel.
optimizer.zero_grad()
inputs = batch[0].to(device)
out = bert(inputs_embeds=inputs).last_hidden_state
print(out.shape)
gloss_probs = out_to_gloss_probs(out.view(-1, frame_encoding_size)).view(vid_batch_size, slrt_input_len, vocab_size+1).permute(1, 0, 2)
print(gloss_probs.shape)

In [None]:
# Greedy (argmax) decoding of the first sample's per-frame predictions.
pred = torch.argmax(gloss_probs, dim = -1)
print(german_tokenizer.decode(pred[:, 0]))

In [None]:
batch[2]

In [None]:
# NOTE(review): batch[2] is the text field of PhoenixDataset samples, whereas the
# later training loop uses batch[1] (gloss) as the CTC target — confirm intent.
print(batch[2][0])
target = german_tokenizer(list(batch[2]), return_tensors='pt', padding='max_length', max_length=max_vid_len)
target.input_ids, target.input_ids.shape, torch.sum(target.attention_mask, dim = -1)

In [None]:
input_lengths = torch.full(size=(vid_batch_size,), fill_value=slrt_input_len, dtype=torch.long)
loss = ctc_loss(gloss_probs, target.input_ids, input_lengths, torch.sum(target.attention_mask, dim = -1))
loss

In [None]:
# Visualise the autograd graph of the loss.
make_dot(loss)

In [None]:
loss.backward()

In [None]:
optimizer.step()

In [None]:
len(phoenix_train_loader)

In [None]:
test_batch[2]

In [None]:
bos = german_tokenizer.bos_token
bos

In [None]:
# Decoder inputs: BOS prepended to each sentence (teacher forcing).
decoder_input = german_tokenizer([bos + t for t in test_batch[2]], return_tensors='pt', padding='max_length', max_length=200)
decoder_input

In [None]:
decoder_input.attention_mask

In [None]:
test_batch[0].shape

In [None]:
# Causal mask: 0 on and below the diagonal, -inf above it (no attending to future tokens).
look_ahead_mask = torch.triu(torch.ones(200, 200) * float('-inf'), diagonal=1)
look_ahead_mask

In [None]:
softmax = nn.Softmax(dim=-1)
ce_loss = nn.CrossEntropyLoss()

In [None]:
# NOTE(review): padded target positions are not excluded from the cross-entropy
# (no ignore_index), and t_decoder's output size is the global vocab_size, while
# the targets are German-tokenizer ids — confirm they are in range.
memory = test_batch[0].to(device)
tgt = decoder_input.input_ids.to(device)
look_ahead_mask = look_ahead_mask.to(device)
outs = t_decoder(memory, tgt, look_ahead_mask)
probs = softmax(outs)
targets = german_tokenizer(list(test_batch[2]), return_tensors='pt', padding='max_length', max_length=200).input_ids.to(device)
loss = ce_loss(outs.view(-1, vocab_size), targets.view(-1))
print(loss)

In [None]:
# Greedy decode of the first sample vs. its reference text.
german_tokenizer.batch_decode(torch.argmax(probs, dim=-1))[0], test_batch[2][0]

In [None]:
loss.backward()
optimizer.step()

In [None]:
# Legacy CTC-only training loop for t_encoder + out_to_gloss_probs.
# NOTE(review): the optimizer built in the setup cell covers t_encoder and
# t_decoder only, so out_to_gloss_probs is not updated by optimizer.step().
train_progress = tqdm(total=len(phoenix_train), desc='Samples processed')
for i, batch in enumerate(phoenix_train_loader):
    optimizer.zero_grad()
    inputs = batch[0].to(device)
    out = t_encoder(inputs)
    # (batch, time, classes) -> (time, batch, classes), the layout nn.CTCLoss expects.
    gloss_probs = out_to_gloss_probs(out).permute(1, 0, 2)
    target = german_tokenizer(list(batch[1]), return_tensors='pt', padding='max_length', max_length=100).to(device)
    input_lengths = torch.full(size=(len(inputs),), fill_value=inputs.shape[1], dtype=torch.long)
    recog_loss = ctc_loss(gloss_probs, target.input_ids, input_lengths, torch.sum(target.attention_mask, dim = -1))

    train_progress.set_postfix(recog_loss=recog_loss.item())
    if i % 100 == 0:
        pred = torch.argmax(gloss_probs, dim = -1)
        print(i, recog_loss.item())
        print(german_tokenizer.decode(pred[:, 0]))
        print(batch[2][0])
    train_progress.update(len(inputs))
    # Backpropagate this batch's loss. The original called backward() on a stale
    # global `loss` from an earlier cell, so this batch's gradients were never computed.
    recog_loss.backward()
    optimizer.step()
train_progress.close()

In [None]:
# Inspect how the German tokenizer splits a gloss sequence.
german_tokenizer.encode("NORDWEST WIND SCHWACH MAESSIG SUED SCHWACH BEWEGEN")

In [None]:
german_tokenizer.decode([50, 19187, 59, 12044, 336])

In [None]:
pred.shape

In [None]:
# NOTE(review): F.softmax without dim= is deprecated; gloss_probs already holds log-probabilities.
F.softmax(gloss_probs)

In [None]:
# NOTE(review): `bert` and `vid_encoder_input` are only defined in commented-out
# cells above; this cell fails on a fresh kernel.
optimizer.zero_grad()
out = bert(vid_encoder_input.cpu()).last_hidden_state
out.shape

In [None]:
# NOTE(review): out_to_gloss_probs outputs vocab_size + 1 classes, so viewing the
# last dim as vocab_size does not match its output size.
gloss_probs = out_to_gloss_probs(out.view(-1, frame_encoding_size)).view(vid_batch_size, slrt_input_len, vocab_size).permute(1, 0, 2)
gloss_probs.shape

In [None]:
# NOTE(review): indexes batch[2] by 'gloss', which does not match the
# (sign, gloss, text) tuples PhoenixDataset produces — written for an older batch format.
batch[2]['gloss'][0]

In [None]:
german_tokenizer([batch[2]['gloss'][0]])

In [None]:
target = german_tokenizer(batch[2]['gloss'], return_tensors='pt', padding='max_length', max_length=max_vid_len)
target.input_ids.shape

In [None]:
input_lengths = torch.full(size=(vid_batch_size,), fill_value=slrt_input_len, dtype=torch.long)

In [None]:
target.attention_mask.shape

In [None]:
input_lengths, batch[0], gloss_probs.shape, target.input_ids.shape

In [None]:
loss = ctc_loss(gloss_probs, target.input_ids, input_lengths, torch.sum(target.attention_mask, dim = -1))
loss

In [None]:
loss.backward()

In [None]:
nn.utils.clip_grad_norm_(bert.parameters(), max_norm=max_gradient_norm, error_if_nonfinite=True)

In [None]:
optimizer.step()

In [None]:
# Sanity check with random data, following the nn.CTCLoss usage example.
# NOTE(review): this cell rebinds the globals `ctc_loss` (blank=0), `target`,
# `input_lengths` and `loss`, and shadows the builtin `input`.
# Targets are to be padded
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
S = 30      # Target sequence length of longest target in batch (padding length)
S_min = 10  # Minimum target length, for demonstration purposes

# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# input = gloss_probs
#print(torch.sum(input, dim=-1))
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)

input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss

In [None]:
input_lengths, target_lengths, input.shape, target.shape

In [None]:
class VideoEncoder(nn.Module):
    """Stacked Transformer encoder over per-frame video features (batch-first).

    Args:
        vid_dim: size of each incoming frame feature vector.
        d_model: width of the Transformer layers.
        n_head: attention heads per layer.
        n_layers: number of stacked encoder layers.
        dropout: dropout used by the positional encoding and the encoder layers
            (previously referenced as an undefined name).
    """

    def __init__(self, vid_dim, d_model, n_head, n_layers, dropout=0.1):
        super().__init__()
        self.d_model = d_model  # read in forward() for embedding scaling
        # Map frame features to the model width; identity when they already match.
        self.input_proj = nn.Identity() if vid_dim == d_model else nn.Linear(vid_dim, d_model)
        # NOTE(review): PositionalEncoding is not defined in this notebook; assumed
        # to take (d_model, dropout) — confirm.
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layer = nn.TransformerEncoderLayer(d_model, n_head, dropout=dropout, batch_first=True)
        # Use PyTorch's stacked encoder explicitly: the notebook's own
        # `TransformerEncoder` class takes no constructor arguments.
        self.encoder = nn.TransformerEncoder(encoder_layer, n_layers)

    def forward(self, src, src_mask=None):
        """Encode `src` (batch, time, vid_dim); `src_mask` is an optional attention mask."""
        src = self.input_proj(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.encoder(src, src_mask)
        return output