In [None]:
import torch
import json
import torch.nn as nn
import os
from transformers import AutoTokenizer
import math
!pip install nltk rouge-score
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)  # Shape: [1, max_len, d_model]

    def forward(self, x):
        if x.size(1) > self.pe.size(1):
            raise ValueError(f"Sequence length {x.size(1)} exceeds max positional encoding length {self.pe.size(1)}")
        return x + self.pe[:, :x.size(1), :].to(x.device)

class ManualEncoderDecoderTransformer(nn.Module):
    def __init__(self, keypoints_dim=75, d_model=384, num_heads=6, num_layers=4, ff_dim=512, max_len=750, vocab_size=30522, pad_idx=0):
        super().__init__()
        self.keypoints_proj = nn.Linear(keypoints_dim, d_model)
        self.input_dropout = nn.Dropout(0.1)

        self.encoder_pe = PositionalEncoding(d_model, max_len)
        self.decoder_pe = PositionalEncoding(d_model, max_len)

        encoder_layer = nn.TransformerEncoderLayer(d_model, num_heads, ff_dim, dropout=0.1)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        decoder_layer = nn.TransformerDecoderLayer(d_model, num_heads, ff_dim, dropout=0.1)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.output_fc = nn.Linear(d_model, vocab_size)
        self.output_fc.weight = self.embedding.weight  # Weight tying

    def forward(self, keypoints, input_ids, tgt_mask=None):
        src_mask = keypoints.abs().sum(dim=-1) == 0
        tgt_pad_mask = input_ids == 0

        x = self.keypoints_proj(keypoints)
        x = self.encoder_pe(self.input_dropout(x)).permute(1, 0, 2)
        memory = self.encoder(x, src_key_padding_mask=src_mask)

        tgt = self.embedding(input_ids)
        tgt = self.decoder_pe(self.input_dropout(tgt)).permute(1, 0, 2)
        out = self.decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_pad_mask, memory_key_padding_mask=src_mask)

        return self.output_fc(out.permute(1, 0, 2))

        return self.fc_out(output.permute(1, 0, 2))

In [None]:
# path to the directory containing the manual model
model_path = "<path_to_model>"

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Recreate the model
model = ManualEncoderDecoderTransformer().to(device)

# Load the model weights
model.load_state_dict(torch.load(model_path, map_location=device))

# Set model to eval mode
model.eval()

# Load tokenizer 
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")




In [None]:
# Load JSON files
def load_from_json(file_path):
    with open(file_path, 'r') as f:
        data = json.load(f)
    return data


In [None]:
# Load the test data
test_data = load_from_json("<path_to_test_data>")

In [None]:
# greedy decoding function
def greedy_decode(model, keypoints_tensor, tokenizer, max_len=80, start_token_id=101, end_token_id=102):
    # Set the model to evaluation mode
    model.eval()
    # Get the device from the model parameters
    device = next(model.parameters()).device

    # Move the tensors to the same device as the model
    # unsqueeze the tensors to add a batch dimension
    keypoints_tensor = keypoints_tensor.unsqueeze(0).to(device)  # [1, T, 75]
    # Get the start token ID from the tokenizer
    generated = torch.tensor([[start_token_id]], dtype=torch.long, device=device)  # [1, 1]

    # Iterate for the maximum length of the sequence
    for step in range(max_len):
        # Generate the target mask for the decoder
        tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(generated.size(1)).to(device)

        with torch.no_grad():
            # Forward pass through the model
            logits = model(keypoints_tensor, generated, tgt_mask=tgt_mask)
            # Get the logits for the last token in the sequence
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            # Append the next token to the generated sequence
            generated = torch.cat([generated, next_token], dim=1)

        # Check if the end token is generated
        if next_token.item() == end_token_id:
            break

    # Decode the generated sequence to text
    return tokenizer.decode(generated[0], skip_special_tokens=True)


In [None]:
# Function to convert subtitle times to seconds
def time_to_float(time_str):

    # Split the time string into hours, minutes, and seconds
    hours, minutes, seconds = time_str.split(':')

    # Convert to float
    hours = float(hours)
    minutes = float(minutes)
    seconds = float(seconds)

    # Convert to total seconds
    total_seconds = hours * 3600 + minutes * 60 + seconds

    # Return the total seconds as a float
    return total_seconds

In [None]:
# Running Manual Tests
print("Running Manual Tests")


# Paths to the keypoints
test_keypoints_dir = "<path_to_test_keypoints>"
test_emotions_dir = "<path_to_test_emotions>"
# Load the test metadata
fps = 25
num_joints = 25
keypoints_dim = num_joints * 3

# Create a list to store the processed test samples
processed_test_samples = []
# Iterate through the test data
for data in test_data:
    # Extract the code and subtitle text from the data
    code = data["code"]
    subtitle_text = data["text"]

    # Build the path to the keypoints file
    keypoints_path = os.path.join(test_keypoints_dir, code + "_keypoints.pt")
    # Check if the keypoints file exists
    if not os.path.exists(keypoints_path):
        print(f"[Warning] Keypoints file not found for {code}")
        continue

    # Load the keypoints tensor
    full_keypoints = torch.load(keypoints_path)

    # Check if the keypoints tensor is empty
    if len(full_keypoints) == 0:
        print(f"[Warning] Skipping {code}: empty keypoints")
        continue
    
    # Process the keypoints tensor
    processed_kps = []
    # Iterate through the frames in the keypoints tensor
    for frame in full_keypoints:
        # Create a tensor for the keypoints in the current frame
        frame_tensor = torch.zeros(keypoints_dim)
        # Check if the frame is empty
        if len(frame) > 0:
            # Get the first person in the frame
            person = frame[0]
            # Flatten the keypoints for the first person
            flat_kps = [coord for part in person for joint in part for coord in joint]
            flat_kps = flat_kps[:keypoints_dim] + [0] * max(0, keypoints_dim - len(flat_kps))
            # Create a tensor for the flattened keypoints
            frame_tensor = torch.tensor(flat_kps[:keypoints_dim], dtype=torch.float32)
        # Append the frame tensor to the list of processed keypoints
        processed_kps.append(frame_tensor)

    # Convert the list of processed keypoints to a tensor
    keypoints_tensor = torch.stack(processed_kps)  # [T, 75]
    # Add the processed sample to the list
    processed_test_samples.append((keypoints_tensor, subtitle_text, code))

# Initialize the values for BLEU and ROUGE scores
smooth_fn = SmoothingFunction().method1
scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
all_bleu_scores = []
all_rouge_scores = []

print("\nRunning Manual Tests")
# Iterate through the processed test samples
for input_tensor, expected_text, code in processed_test_samples:
    print(f"\nRunning Test: {code}")
    print("Expected Text:", expected_text)
    print("Input Shape:", input_tensor.shape)

    # Get the predicted text using greedy decoding
    result = greedy_decode(model, input_tensor.to(device), tokenizer)
    print("Result:", result)

    # BLEU
    ref_tokens = [expected_text.split()]
    gen_tokens = result.split()
    bleu = sentence_bleu(ref_tokens, gen_tokens, smoothing_function=smooth_fn)
    print("BLEU: " , bleu)
    all_bleu_scores.append(bleu)
        
    # ROUGE
    rouge = scorer.score(expected_text, result)
    all_rouge_scores.append(rouge)
    print("ROUGE:", rouge)

# Get the average BLEU and ROUGE scores
avg_bleu = sum(all_bleu_scores) / len(all_bleu_scores)
avg_rouge1 = sum(r['rouge1'].fmeasure for r in all_rouge_scores) / len(all_rouge_scores)
avg_rougeL = sum(r['rougeL'].fmeasure for r in all_rouge_scores) / len(all_rouge_scores)

Running Manual Tests

Running Manual Tests

Running Test: 001
Expected Text: Help me.
Input Shape: torch.Size([25, 75])
Result: i ' m you.
BLEU:  0
ROUGE: {'rouge1': Score(precision=0.0, recall=0.0, fmeasure=0.0), 'rougeL': Score(precision=0.0, recall=0.0, fmeasure=0.0)}

Running Test: 002
Expected Text: It would start to concern me a little bit, looking at five years, looking at retiring.
Input Shape: torch.Size([124, 75])




Result: i ' m the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the.
BLEU:  0
ROUGE: {'rouge1': Score(precision=0.0, recall=0.0, fmeasure=0.0), 'rougeL': Score(precision=0.0, recall=0.0, fmeasure=0.0)}

Running Test: 003
Expected Text: Oh, God, I hate these Land Cruisers.
Input Shape: torch.Size([75, 75])
Result: i ' m a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a
BLEU:  0
ROUGE: {'rouge1': Score(precision=0.012658227848101266, recall=0.14285714285714285, fmeasure=0.023255813953488372), 'rougeL': Score(precision=0.012658227848101266, recall=0.14285714285714285, fmeasure=0.023255813953488372)}

Running Test: 004
Expected Text: Hey, that's lovely!
Input Shape: torch.Size([38, 75])
Result: i ' m you ' re.
BLEU:  0
ROUGE: {'rouge1': Score(precision=0.0, recall=0.0, fmeasure=0.0), 'ro

In [13]:
print("avg_bleu: ", avg_bleu)
print("avg_rouge1: ", avg_rouge1)
print("avg_rougeL: ", avg_rougeL)

avg_bleu:  0.005196222123689377
avg_rouge1:  0.0865284578839326
avg_rougeL:  0.07827895943536672
