In [None]:
import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from transformers import BertTokenizer, AutoTokenizer
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from glob import glob
from tqdm import tqdm
from collections import Counter
import re
import os
import re
from glob import glob
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt
from torch.nn.utils.rnn import pad_sequence
import matplotlib.pyplot as plt
from transformers import get_cosine_schedule_with_warmup

In [None]:
# Pick the compute device once: the first CUDA GPU when one is visible, the CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# This function parses VTT files and extracts the start time, end time, and text.
def parse_vtt(vtt_path, video_id):
    """Parse a WebVTT subtitle file into a list of cue dictionaries.

    Handles the WebVTT variants that the simple "timing line first" approach
    misses: cues preceded by an identifier line, timing lines followed by cue
    settings (e.g. ``align:start position:0%``) and timestamps without hours
    (``mm:ss.ttt``).

    Args:
        vtt_path: Path to the ``.vtt`` file (read as UTF-8).
        video_id: Identifier stored on every cue so cues from different videos
            with overlapping timestamps can be told apart.

    Returns:
        A list of dicts with keys ``"start"`` and ``"end"`` (seconds, float),
        ``"text"`` (payload lines joined with single spaces) and ``"video_id"``.
        Blocks without a timing line (the ``WEBVTT`` header, ``NOTE`` blocks)
        and cues with no payload lines are skipped.
    """
    def time_str_to_seconds(time_str):
        # Timestamps are "hh:mm:ss.ttt" or "mm:ss.ttt" (the hours field is optional).
        *hm_fields, sec_field = time_str.split(":")
        hours, minutes = ([0, 0] + [int(v) for v in hm_fields])[-2:]
        whole, _, frac = sec_field.partition(".")
        fraction = int(frac) / 10 ** len(frac) if frac else 0
        return hours * 3600 + minutes * 60 + int(whole) + fraction

    # Read the VTT file
    with open(vtt_path, "r", encoding="utf-8") as f:
        vtt_text = f.read()

    # Blocks are separated by one or more blank lines (whitespace-only lines count as blank)
    blocks = re.split(r'\n\s*\n', vtt_text.strip())
    entries = []

    for block in blocks:
        lines = block.strip().splitlines()
        # The timing line is the first line, or the second when the cue has an identifier
        timing_idx = next((i for i, line in enumerate(lines[:2]) if "-->" in line), None)
        if timing_idx is None or timing_idx + 1 >= len(lines):
            continue

        start_str, rest = lines[timing_idx].split("-->", 1)
        # Anything after the end timestamp is cue settings, not part of the time
        end_str = rest.split()[0]
        text = " ".join(lines[timing_idx + 1:]).strip()
        entries.append({
            "start": time_str_to_seconds(start_str.strip()),
            "end": time_str_to_seconds(end_str.strip()),
            "text": text,
            "video_id": video_id,
        })

    return entries

In [None]:
# This function extracts the video ID from the file path.
def extract_video_id(path):
    return os.path.basename(path).split("_keypoints")[0]

In [None]:
# path to the keypoints and subtitles files and create a list of files (training split)
keypoints_path = "<keypoints_path_here>"
keypoints_files = glob(os.path.join(keypoints_path, "*_keypoints.pth"))
subtitles_path = "<subtitles_path_here>"
subtitles_files = glob(os.path.join(subtitles_path, "*.vtt"))

# Report how many files were actually matched (globbing the bare directory path
# would only ever report 0 or 1).
print(f"Found {len(keypoints_files)} keypoint files and {len(subtitles_files)} subtitle files")

# (video_id, keypoints) pairs and a flat list of subtitle cue dicts
keypoints = []
subtitles = []

# Load every keypoints file; the video ID is the file name up to "_keypoints"
for k in keypoints_files:
    print(k)
    # NOTE(review): torch.load unpickles the file — only load trusted data.
    temp_keypoints = torch.load(k)
    keypoints.append((extract_video_id(k), temp_keypoints))

# Parse every subtitle file; the video ID is the file name without ".vtt" and must
# match the keypoints video ID for the dataset to pair them later
for s in subtitles_files:
    print(s)
    video_id = os.path.splitext(os.path.basename(s))[0]
    subtitles.extend(parse_vtt(s, video_id))

print("finished")

In [None]:
# path to the keypoints and subtitles files and create a list of files (test split)
keypoints_path = "<keypoints_path_here>"
keypoints_files = glob(os.path.join(keypoints_path, "*_keypoints.pth"))
subtitles_path = "<subtitles_path_here>"
subtitles_files = glob(os.path.join(subtitles_path, "*.vtt"))

# Report how many files were actually matched (globbing the bare directory path
# would only ever report 0 or 1).
print(f"Found {len(keypoints_files)} keypoint files and {len(subtitles_files)} subtitle files")

# (video_id, keypoints) pairs and a flat list of subtitle cue dicts for testing
keypoints_test = []
subtitles_test = []

# Load every keypoints file; the video ID is the file name up to "_keypoints"
for k in keypoints_files:
    print(k)
    # NOTE(review): torch.load unpickles the file — only load trusted data.
    temp_keypoints = torch.load(k)
    keypoints_test.append((extract_video_id(k), temp_keypoints))

# Parse every subtitle file; the video ID is the file name without ".vtt" and must
# match the keypoints video ID for the dataset to pair them later
for s in subtitles_files:
    print(s)
    video_id = os.path.splitext(os.path.basename(s))[0]
    subtitles_test.extend(parse_vtt(s, video_id))

print("finished")

In [None]:
# Function to convert subtitle times to seconds
def time_to_float(time_str):
    """Convert an ``"hh:mm:ss[.fff]"`` timestamp into total seconds.

    Every field is parsed with ``float``; a string without exactly two colons
    raises ``ValueError``.
    """
    hours, minutes, seconds = (float(field) for field in time_str.split(':'))
    return hours * 3600 + minutes * 60 + seconds

In [None]:
# Dataset class for keypoints and subtitles
class ManualKeypointsAndSubtitlesDataset(Dataset):
    """Pairs each subtitle cue with the keypoint frames it spans.

    Each sample is a ``(keypoints, input_ids)`` tuple:
    ``keypoints`` is a float32 tensor of shape ``(num_frames, num_joints * 3)`` and
    ``input_ids`` is a 1-D tensor of ``max_length`` token IDs (truncated/padded by
    the tokenizer). All samples are built eagerly in ``__init__``.
    """

    def __init__(self, keypoints_data, subtitle_entries, fps=25, tokenizer=None, max_length=80, num_joints=25):
        """
        Args:
            keypoints_data: Mapping ``video_id -> sequence of frames``; each frame is a
                (possibly empty) list of detected people, of which the first is used.
            subtitle_entries: Cue dicts with ``"video_id"``, ``"start"``, ``"end"``
                (seconds) and ``"text"`` keys.
            fps: Frame rate used to convert cue times into frame indices.
            tokenizer: Hugging Face-style tokenizer; defaults to ``bert-base-uncased``.
            max_length: Token length every subtitle is truncated/padded to.
            num_joints: Number of joints kept per frame (3 values per joint).
        """
        self.fps = fps
        self.max_length = max_length
        self.num_joints = num_joints
        # NOTE(review): kept for callers; the dataset itself does not use it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = tokenizer or AutoTokenizer.from_pretrained("bert-base-uncased")
        self.samples = self.build_samples(keypoints_data, subtitle_entries)

    def _frame_to_tensor(self, frame):
        """Flatten the first person in ``frame`` into a ``num_joints * 3`` float32 vector.

        Frames with nobody detected become all zeros; surplus values are truncated
        and missing ones zero-padded.
        """
        size = self.num_joints * 3
        if len(frame) == 0:
            return torch.zeros(size)
        person = frame[0]
        # assumes person -> parts -> joints -> coordinates nesting — TODO confirm against the keypoint extractor
        flat_kps = [coord for part in person for joint in part for coord in joint]
        flat_kps = flat_kps[:size] + [0] * max(0, size - len(flat_kps))
        return torch.tensor(flat_kps, dtype=torch.float32)

    def build_samples(self, keypoints_dict, subtitles):
        """Build ``(keypoints, input_ids)`` samples for every usable subtitle cue.

        A cue is skipped when its video has no keypoints, when it ends after the
        last available frame, or when it covers zero frames.
        """
        samples = []
        for sub in subtitles:
            video_id = sub["video_id"]
            if video_id not in keypoints_dict:
                continue
            video_kps = keypoints_dict[video_id]

            # Cue times (seconds) -> frame indices; int() truncates
            start_frame = int(sub["start"] * self.fps)
            end_frame = int(sub["end"] * self.fps)
            if end_frame > len(video_kps):
                continue

            keypoints_seq = video_kps[start_frame:end_frame]
            if len(keypoints_seq) == 0:
                continue

            frames = torch.stack([self._frame_to_tensor(frame) for frame in keypoints_seq])
            # squeeze(0) removes only the batch dim added by return_tensors="pt",
            # so the result stays 1-D even when max_length == 1
            input_ids = self.tokenizer(
                sub["text"],
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt",
            )["input_ids"].squeeze(0)
            samples.append((frames, input_ids))

        return samples

    # Get the length of the dataset
    def __len__(self):
        return len(self.samples)

    # Get a sample from the dataset
    def __getitem__(self, idx):
        return self.samples[idx]

In [None]:
# Tokenizer that turns subtitle text into BERT input IDs
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Training dataset built from a video_id -> keypoints mapping and the parsed cues
keypoints_dict = dict(keypoints)
manual_train_dataset = ManualKeypointsAndSubtitlesDataset(
    keypoints_dict,
    subtitles,
    fps=25,
    tokenizer=tokenizer,
)

# Sanity check: number of samples and the tensor shapes of the first one
print("Dataset size:", len(manual_train_dataset))
k, s = manual_train_dataset[0]
print("Keypoints Shape:", k.shape)
print("Subtitles Shape:", s.shape)

In [None]:
# Tokenizer that turns subtitle text into BERT input IDs
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Test dataset built from a video_id -> keypoints mapping and the parsed test cues
keypoints_dict_test = dict(keypoints_test)
manual_test_dataset = ManualKeypointsAndSubtitlesDataset(
    keypoints_dict_test,
    subtitles_test,
    fps=25,
    tokenizer=tokenizer,
)

# Sanity check: number of samples and the tensor shapes of the first one
print("Dataset size:", len(manual_test_dataset))
k, s = manual_test_dataset[0]
print("Keypoints Shape:", k.shape)
print("Subtitles Shape:", s.shape)

In [None]:
# Function to collate and pad the samples into batches
def manual_collate_fn(batch):
    """Stack ``(keypoints, input_ids)`` samples into zero-padded batch tensors.

    Returns:
        ``(keypoints, input_ids)`` where keypoints is ``[B, T_max, F]`` padded with
        0.0 along time and input_ids is ``[B, L_max]`` padded with token 0.
    """
    keypoint_seqs, token_seqs = zip(*batch)
    padded_keypoints = pad_sequence(keypoint_seqs, batch_first=True, padding_value=0.0)
    padded_tokens = pad_sequence(token_seqs, batch_first=True, padding_value=0)
    return padded_keypoints, padded_tokens

In [None]:
# Shuffled training batches of 128 samples, padded per batch by manual_collate_fn
manual_train_loader = DataLoader(
    manual_train_dataset,
    batch_size=128,
    shuffle=True,
    collate_fn=manual_collate_fn,
)

In [None]:
# Create the DataLoader for testing.
# No shuffling: validation metrics don't need it, and a fixed order keeps the per-epoch
# sample decoded from the first usable batch identical, so outputs are comparable across epochs.
manual_test_loader = DataLoader(manual_test_dataset, batch_size=128, shuffle=False, collate_fn=manual_collate_fn)

In [None]:
# Define Positional Encoding class
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017) added to the input.

    Expects batch-first input of shape ``(batch, seq_len, d_model)``; sequences longer
    than ``max_len`` raise ``ValueError``.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column
        pe[:, 1::2] = torch.cos(position * div_term)[:, : d_model // 2]
        # Buffer: follows model.to(device), so no per-forward host->device copy.
        # persistent=False keeps it out of state_dict, so existing checkpoints still load.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions."""
        if x.size(1) > self.pe.size(1):
            raise ValueError(f"Sequence length {x.size(1)} exceeds max positional encoding length {self.pe.size(1)}")
        # .to() is a no-op when the buffer already lives on x's device
        return x + self.pe[:, :x.size(1), :].to(x.device)


In [None]:
# Define the Encoder-Decoder Transformer model for manual keypoints and subtitles
class ManualEncoderDecoderTransformer(nn.Module):
    """Encodes a keypoint sequence and autoregressively decodes subtitle tokens.

    Keypoint frames are linearly projected to ``d_model`` and passed through a
    Transformer encoder. Subtitle tokens are embedded and decoded with cross-attention
    to the encoded frames. The output projection shares its weight with the token
    embedding (weight tying).

    Embedding scale: the tied embedding is initialised with std ``d_model ** -0.5`` and
    multiplied by ``sqrt(d_model)`` on input (Vaswani et al., 2017, §3.4). This keeps the
    initial logits near unit scale. With nn.Embedding's default N(0, 1) init, tied logits
    start with std ≈ sqrt(d_model), which inflates the initial loss.
    Checkpoints trained without input scaling should be built with
    ``scale_embedding=False`` to reproduce their behaviour.
    """

    def __init__(self, keypoints_dim=75, d_model=384, num_heads=6, num_layers=4, ff_dim=512, max_len=1024, vocab_size=30522, pad_idx=0, scale_embedding=True):
        super().__init__()
        self.pad_idx = pad_idx
        # Keypoints projection layer to transform keypoints to the model dimension
        self.keypoints_proj = nn.Linear(keypoints_dim, d_model)
        # Dropout applied to encoder and decoder inputs
        self.input_dropout = nn.Dropout(0.1)

        # Positional encoding layers for encoder and decoder
        self.encoder_pe = PositionalEncoding(d_model, max_len)
        self.decoder_pe = PositionalEncoding(d_model, max_len)

        # Transformer encoder/decoder (sequence-first layout, hence the permutes in forward)
        encoder_layer = nn.TransformerEncoderLayer(d_model, num_heads, ff_dim, dropout=0.1)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        decoder_layer = nn.TransformerDecoderLayer(d_model, num_heads, ff_dim, dropout=0.1)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        # Token embedding and tied output projection
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.output_fc = nn.Linear(d_model, vocab_size)
        self.output_fc.weight = self.embedding.weight

        # Small init for the shared weight; re-zero the padding row afterwards
        nn.init.normal_(self.embedding.weight, mean=0.0, std=d_model ** -0.5)
        if pad_idx is not None:
            with torch.no_grad():
                self.embedding.weight[pad_idx].zero_()
        self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0

    def forward(self, keypoints, input_ids, tgt_mask=None):
        """Return token logits ``[B, L, vocab_size]``.

        Args:
            keypoints: ``[B, T, keypoints_dim]``. Frames whose values are all zero are
                treated as padding.
            input_ids: ``[B, L]`` decoder input tokens; ``pad_idx`` positions are masked.
            tgt_mask: Optional causal mask for the decoder self-attention.
        """
        # Source padding mask: all-zero frames (collate padding or nobody detected)
        src_mask = keypoints.abs().sum(dim=-1) == 0
        # A sequence with every frame masked would make attention softmax NaN;
        # keep its first frame visible instead.
        fully_masked = src_mask.all(dim=1)
        if fully_masked.any():
            src_mask[fully_masked, 0] = False
        # Target padding mask uses the configured pad index
        tgt_pad_mask = (input_ids == self.pad_idx) if self.pad_idx is not None else None

        # Encoder: project keypoints, add positions, switch to (T, B, D)
        x = self.keypoints_proj(keypoints)
        x = self.encoder_pe(self.input_dropout(x)).permute(1, 0, 2)
        memory = self.encoder(x, src_key_padding_mask=src_mask)

        # Decoder: scaled token embeddings plus positions, in (L, B, D) layout
        tgt = self.embedding(input_ids) * self.embed_scale
        tgt = self.decoder_pe(self.input_dropout(tgt)).permute(1, 0, 2)
        out = self.decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_pad_mask, memory_key_padding_mask=src_mask)

        # Back to (B, L, D), then project to vocabulary logits
        return self.output_fc(out.permute(1, 0, 2))

In [None]:
# First run only: build a freshly initialised model and place it on the selected device.
# (To resume from saved weights, use the checkpoint-loading cell instead.)
manual_model = ManualEncoderDecoderTransformer()
manual_model = manual_model.to(device)

In [None]:
# Only run this when loading an existing model:
# rebuild the architecture, then restore the saved weights onto `device`.
manual_model = ManualEncoderDecoderTransformer().to(device)

# A state_dict only contains tensors, so weights_only=True is sufficient. It also
# prevents torch.load from unpickling (and executing) arbitrary objects in the file.
manual_model.load_state_dict(
    torch.load("<path_to_model>", map_location=device, weights_only=True)
)

In [None]:
# Beam search decoding function
def sample_decode_beam(model, keypoints, tokenizer, beam_width=3, max_len=300, eos_token_id=102):
    """Decode a single keypoint sequence into text with beam search.

    Args:
        model: Called as ``model(keypoints[B, T, F], input_ids[B, L], tgt_mask=...)``
            and returning logits ``[B, L, vocab]``; switched to eval mode here.
        keypoints: Keypoints for ONE sample, without a batch dimension.
        tokenizer: Supplies ``cls_token_id`` (start token) and ``decode``.
        beam_width: Hypotheses kept per step (also the per-hypothesis fan-out).
        max_len: Maximum number of decoding steps.
        eos_token_id: Token that finishes a hypothesis.

    Returns:
        Decoded text (special tokens skipped) of the hypothesis with the lowest
        summed negative log-probability.
    """
    model.eval()
    device = keypoints.device
    beams = [(torch.tensor([tokenizer.cls_token_id], device=device), 0.0)]

    for _ in range(max_len):
        # Stop once every surviving hypothesis has emitted EOS (further steps can't change them)
        active = [seq for seq, _ in beams if seq[-1].item() != eos_token_id]
        if not active:
            break

        # Unfinished hypotheses all grew by one token per step, so they share a length
        # and can be scored in a single batched forward pass.
        input_ids = torch.stack(active)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(input_ids.size(1)).to(device)
        batch_keypoints = keypoints.unsqueeze(0).expand(len(active), *keypoints.shape)
        with torch.no_grad():
            logits = model(batch_keypoints, input_ids, tgt_mask=tgt_mask)

        # log_softmax is numerically safer than log(softmax + eps)
        log_probs = torch.log_softmax(logits[:, -1, :], dim=-1)
        top_log_probs, top_indices = log_probs.topk(beam_width, dim=-1)

        # Expand in the original beam order so ties keep a stable ordering
        candidates = []
        row = 0
        for seq, score in beams:
            if seq[-1].item() == eos_token_id:
                candidates.append((seq, score))
                continue
            for log_p, token in zip(top_log_probs[row].tolist(), top_indices[row]):
                candidates.append((torch.cat([seq, token.unsqueeze(0)]), score - log_p))
            row += 1

        # Keep the beam_width lowest-cost hypotheses (plain floats: no device sync per compare)
        beams = sorted(candidates, key=lambda c: c[1])[:beam_width]

    return tokenizer.decode(beams[0][0], skip_special_tokens=True)

In [None]:
# Function to train and validate the model
def _run_epoch(model, loader, criterion, device, pad_token_id, eos_token_id,
               optimizer=None, scheduler=None, desc="", max_src_len=1024, max_tgt_len=80):
    """Run one pass over ``loader``: trains when ``optimizer`` is given, otherwise evaluates.

    Batches longer than ``max_src_len`` frames or ``max_tgt_len`` tokens are skipped.
    Returns ``(avg_loss, accuracy_percent)`` over the batches actually processed.
    Accuracy counts next-token predictions, excluding padding and EOS targets.
    ``avg_loss`` is NaN and accuracy 0.0 when nothing was processed.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, total_correct, total_tokens, n_batches = 0.0, 0, 0, 0

    with torch.set_grad_enabled(training):
        for keypoints, targets in tqdm(loader, desc=desc):
            # Skip batches beyond the positional-encoding / token budget
            if keypoints.size(1) > max_src_len or targets.size(1) > max_tgt_len:
                print("Skipping long sequence")
                continue

            keypoints, targets = keypoints.to(device), targets.to(device)
            # Teacher forcing: predict token t+1 from tokens <= t
            decoder_input = targets[:, :-1]
            decoder_target = targets[:, 1:]
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(decoder_input.size(1)).to(device)

            if training:
                optimizer.zero_grad()
            logits = model(keypoints, decoder_input, tgt_mask=tgt_mask)
            # reshape (not view) also works if logits are non-contiguous
            loss = criterion(logits.reshape(-1, logits.size(-1)), decoder_target.reshape(-1))

            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

            # Same token mask for train and validation so both accuracies are comparable
            preds = logits.argmax(dim=-1)
            mask = (decoder_target != pad_token_id) & (decoder_target != eos_token_id)
            total_correct += ((preds == decoder_target) & mask).sum().item()
            total_tokens += mask.sum().item()
            total_loss += loss.item()
            n_batches += 1

    avg_loss = total_loss / n_batches if n_batches else float("nan")
    accuracy = 100.0 * total_correct / total_tokens if total_tokens else 0.0
    return avg_loss, accuracy


def _plot_history(train_accuracies, val_accuracies, train_losses, val_losses):
    """Plot per-epoch accuracy and loss curves side by side."""
    epochs_range = range(1, len(train_losses) + 1)
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

    acc_ax.plot(epochs_range, train_accuracies, label='Train Accuracy')
    acc_ax.plot(epochs_range, val_accuracies, label='Validation Accuracy')
    acc_ax.set(xlabel="Epoch", ylabel="Accuracy (%)", title="Accuracy Over Epochs")
    acc_ax.legend()
    acc_ax.grid(True)

    loss_ax.plot(epochs_range, train_losses, label='Train Loss')
    loss_ax.plot(epochs_range, val_losses, label='Validation Loss')
    loss_ax.set(xlabel="Epoch", ylabel="Loss", title="Loss Over Epochs")
    loss_ax.legend()
    loss_ax.grid(True)

    fig.tight_layout()
    plt.show()


def train_validate_model(
    model, train_loader, val_loader, optimizer, criterion, scheduler,
    device, tokenizer, num_epochs=300, eos_token_id=102, pad_token_id=0,
    save_path="best_manual_model_batch1.pt", patience=15, max_src_len=1024, max_tgt_len=80,
):
    """Train with per-epoch validation, checkpointing and early stopping, then plot curves.

    After each epoch:
    - prints loss/accuracy for both splits and a beam-search decode of one validation sample;
    - saves ``model.state_dict()`` to ``save_path`` whenever validation loss improves;
    - stops early after ``patience`` epochs without improvement.
    Losses are averaged over processed batches only (skipped long batches are excluded).
    """
    best_val_loss = float('inf')
    patience_counter = 0
    train_accuracies, val_accuracies = [], []
    train_losses, val_losses = [], []

    for epoch in range(num_epochs):
        avg_train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, pad_token_id, eos_token_id,
            optimizer=optimizer, scheduler=scheduler,
            desc=f"Epoch {epoch+1}/{num_epochs} [Train]",
            max_src_len=max_src_len, max_tgt_len=max_tgt_len,
        )
        train_accuracies.append(train_acc)
        train_losses.append(avg_train_loss)

        avg_val_loss, val_acc = _run_epoch(
            model, val_loader, criterion, device, pad_token_id, eos_token_id,
            desc="Validating", max_src_len=max_src_len, max_tgt_len=max_tgt_len,
        )
        val_accuracies.append(val_acc)
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1} Summary:\nTrain Loss: {avg_train_loss:.4f}, Acc: {train_acc:.2f}% | "
              f"Val Loss: {avg_val_loss:.4f}, Acc: {val_acc:.2f}%")

        # Decode the first sample of the first validation batch that fits the length budget
        for keypoints_sample, _ in val_loader:
            if keypoints_sample.size(1) <= max_src_len:
                sample_text = sample_decode_beam(model, keypoints_sample[0].to(device), tokenizer, eos_token_id=eos_token_id)
                print("Sample decoded:", sample_text)
                break

        # Early stopping on validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
        else:
            patience_counter += 1
            print(f"No improvement in validation loss. {patience_counter} / {patience}")
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    _plot_history(train_accuracies, val_accuracies, train_losses, val_losses)

In [None]:
# Training length used to size the LR schedule.
# Keep in sync with num_epochs passed to train_validate_model.
NUM_EPOCHS = 500

# Cross-entropy that ignores padding targets (token ID 0, the collate/tokenizer pad value)
criterion = nn.CrossEntropyLoss(ignore_index=0)

# AdamW with decoupled weight decay
# NOTE(review): lr=1e-5 is low for a transformer trained from scratch — consider tuning.
optimizer = torch.optim.AdamW(manual_model.parameters(), lr=1e-5, weight_decay=0.01)

# 100 warm-up steps, then a cosine schedule spanning every batch of every epoch
total_steps = NUM_EPOCHS * len(manual_train_loader)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=total_steps,
)

In [None]:
# Before training, free unreachable Python objects and release cached-but-unused CUDA memory
import gc

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Train with per-epoch validation on the held-out loader (early stopping inside)
train_validate_model(
    manual_model,
    manual_train_loader,
    manual_test_loader,
    optimizer,
    criterion,
    scheduler,
    device,
    tokenizer,
    num_epochs=500,
)

Example output:

Epoch 1 Summary:
Train Loss: 82.1442, Acc: 0.35% | Val Loss: 71.8930, Acc: 0.66%
Sample decoded: buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons buttons
Epoch 2/500 [Train]: 100%|██████████| 39/39 [00:09<00:00,  4.11it/s]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.12it/s]
Epoch 2 Summary:
Train Loss: 60.1969, Acc: 0.82% | Val Loss: 48.1628, Acc: 2.61%
Sample decoded: bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian bolivian
Epoch 3/500 [Train]: 100%|██████████| 39/39 [00:09<00:00,  3.95it/s]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.01it/s]
Epoch 3 Summary:
Train Loss: 43.3728, Acc: 1.14% | Val Loss: 35.9727, Acc: 2.90%
Sample decoded: ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '
Epoch 4/500 [Train]: 100%|██████████| 39/39 [00:09<00:00,  4.06it/s]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.42it/s]
Epoch 4 Summary:
Train Loss: 35.4826, Acc: 1.43% | Val Loss: 30.8406, Acc: 3.98%
Sample decoded: ' ' ' ' ' ' ' ' ' ' ' ' ' '
Epoch 5/500 [Train]: 100%|██████████| 39/39 [00:09<00:00,  4.03it/s]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.09it/s]