In [None]:
import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from transformers import BertTokenizer, AutoTokenizer
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from glob import glob
from tqdm import tqdm
from collections import Counter
import re
import os
import re
from glob import glob
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt

In [None]:
# Pick the compute device once for the whole notebook: CUDA when a GPU is
# visible to PyTorch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def parse_vtt(vtt_path, video_id):
    """Parse a WebVTT subtitle file into a list of timed cue dictionaries.

    Args:
        vtt_path: Path of the ``.vtt`` file to read (decoded as UTF-8).
        video_id: Identifier copied onto every returned entry so cues from
            different videos with overlapping times can be told apart.

    Returns:
        A list of dicts with keys ``"start"`` and ``"end"`` (seconds, float),
        ``"text"`` (the cue's text lines joined with single spaces) and
        ``"video_id"``. Blocks that have no timing line, or no text after the
        timing line, are skipped.

    Raises:
        ValueError: If a timing line holds a timestamp that cannot be parsed.
    """
    def time_str_to_seconds(time_str):
        # WebVTT allows the hour field to be omitted ("MM:SS.mmm").
        fields = time_str.split(":")
        if len(fields) == 2:
            fields = ["0"] + fields
        h, m, s = fields
        # Tolerate a missing fractional part instead of failing on the unpack.
        s, _, ms = s.partition(".")
        return int(h) * 3600 + int(m) * 60 + int(s) + int(ms or 0) / 1000

    with open(vtt_path, "r", encoding="utf-8") as f:
        vtt_text = f.read()

    # Cues are separated by one or more blank lines.
    blocks = re.split(r'\n\n+', vtt_text.strip())
    entries = []

    for block in blocks:
        lines = block.strip().splitlines()

        # The timing line is normally first, but a cue may start with an
        # identifier line, in which case the timing line is second.
        timing_idx = next(
            (i for i, line in enumerate(lines[:2]) if "-->" in line), None
        )
        if timing_idx is None:
            continue

        text_lines = lines[timing_idx + 1:]
        if not text_lines:
            continue

        start_str, _, remainder = lines[timing_idx].partition("-->")
        # Anything after the end timestamp is cue settings
        # (e.g. "align:start position:0%") and is ignored.
        end_fields = remainder.split()
        if not end_fields:
            raise ValueError(
                f"Malformed cue timing line in {vtt_path!r}: {lines[timing_idx]!r}"
            )

        entries.append({
            "start": time_str_to_seconds(start_str.strip()),
            "end": time_str_to_seconds(end_fields[0]),
            "text": " ".join(text_lines).strip(),
            "video_id": video_id  # To distinguish overlapping times
        })

    return entries

In [None]:
def extract_video_id(path):
    """Return the video id encoded in a keypoints file path.

    The id is the part of the file's base name that precedes the first
    ``"_keypoints"`` marker; a name without the marker is returned whole.
    """
    filename = os.path.basename(path)
    video_id, _, _ = filename.partition("_keypoints")
    return video_id

In [None]:

# Training split: load the per-video keypoint files and parse the subtitles.
keypoints_path = "/content/drive/MyDrive/model_training/normalised_keypoints_output_extracted/batch1/train"
keypoints_files = glob(os.path.join(keypoints_path, "*_keypoints.pth"))
subtitles_path = "/content/drive/MyDrive/model_training/subtitles/batch1/train"
subtitles_files = glob(os.path.join(subtitles_path, "*.vtt"))

# Report how many keypoint files were matched. Globbing the bare directory
# path (as was done before) can only ever yield 0 or 1.
print(len(keypoints_files))

keypoints = []   # (video_id, object loaded from the .pth file) pairs
subtitles = []   # flat list of cue dicts gathered from every .vtt file

for k in keypoints_files:
  print(k)
  temp_keypoints = torch.load(k)
  keypoints.append((extract_video_id(k), temp_keypoints))

for s in subtitles_files:
    print(s)
    # The subtitle file's base name (without extension) is used as the video id.
    base_name = os.path.basename(s).replace(".vtt", "")
    parsed_subs = parse_vtt(s, base_name)
    subtitles.extend(parsed_subs)

print("finished")

In [None]:
path = "/content/drive/MyDrive/model_training/i3d_features/train"

# One glob per I3D output kind; the matching suffix is stripped from the file
# name to recover the video name.
clip_gt_files = glob(os.path.join(path, "*_clip_gt_tensor.pt"))
clip_ix_files = glob(os.path.join(path, "*_clip_ix_tensor.pt"))
clip_preds_files = glob(os.path.join(path, "*_preds_tensor.pt"))

# Each list collects (video_name, loaded tensor) pairs.
clip_gt_features, clip_ix_features, clip_preds_features = [], [], []

for tensor_files, suffix, loaded in (
    (clip_gt_files, "_clip_gt_tensor.pt", clip_gt_features),
    (clip_ix_files, "_clip_ix_tensor.pt", clip_ix_features),
    (clip_preds_files, "_preds_tensor.pt", clip_preds_features),
):
    for tensor_path in tensor_files:
        print(tensor_path)
        video_name = os.path.basename(tensor_path).replace(suffix, "")
        loaded.append((video_name, torch.load(tensor_path)))

print("finished")

In [None]:
def get_swin_features(path):
    """Load a ``.npy`` feature array and return it as a float32 tensor.

    The array's shape and dtype are printed, and the resulting tensor is
    placed on the GPU when CUDA is available, otherwise on the CPU.
    """
    array = np.load(path)
    print(f"Loaded NumPy array with shape {array.shape} and dtype {array.dtype}")

    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.from_numpy(array).float().to(target)

In [None]:
path = "/content/drive/MyDrive/model_training/swin_features/train"
swin_files = glob(os.path.join(path, "*.npy"))

# (video_name, feature tensor) pairs for the training split.
swin_features = []

for npy_path in swin_files:
    print(npy_path)
    video_name = os.path.basename(npy_path).replace(".npy", "")
    swin_features.append((video_name, get_swin_features(npy_path)))

print("finished")

In [None]:
# Test split: load the per-video keypoint files and parse the subtitles.
keypoints_path = "/content/drive/MyDrive/model_training/normalised_keypoints_output_extracted/batch1/test"
keypoints_files = glob(os.path.join(keypoints_path, "*_keypoints.pth"))
subtitles_path = "/content/drive/MyDrive/model_training/subtitles/batch1/test"
subtitles_files = glob(os.path.join(subtitles_path, "*.vtt"))

# Report how many keypoint files were matched. Globbing the bare directory
# path (as was done before) can only ever yield 0 or 1.
print(len(keypoints_files))

keypoints_test = []   # (video_id, object loaded from the .pth file) pairs
subtitles_test = []   # flat list of cue dicts gathered from every .vtt file

for k in keypoints_files:
  print(k)
  temp_keypoints = torch.load(k)
  keypoints_test.append((extract_video_id(k), temp_keypoints))

for s in subtitles_files:
    print(s)
    # The subtitle file's base name (without extension) is used as the video id.
    base_name = os.path.basename(s).replace(".vtt", "")
    parsed_subs = parse_vtt(s, base_name)
    subtitles_test.extend(parsed_subs)

print("finished")

In [None]:
path = "/content/drive/MyDrive/model_training/i3d_features/test"

# One glob per I3D output kind; the matching suffix is stripped from the file
# name to recover the video name.
clip_gt_files = glob(os.path.join(path, "*_clip_gt_tensor.pt"))
clip_ix_files = glob(os.path.join(path, "*_clip_ix_tensor.pt"))
clip_preds_files = glob(os.path.join(path, "*_preds_tensor.pt"))

# Each list collects (video_name, loaded tensor) pairs.
clip_gt_features_test, clip_ix_features_test, clip_preds_features_test = [], [], []

for tensor_files, suffix, loaded in (
    (clip_gt_files, "_clip_gt_tensor.pt", clip_gt_features_test),
    (clip_ix_files, "_clip_ix_tensor.pt", clip_ix_features_test),
    (clip_preds_files, "_preds_tensor.pt", clip_preds_features_test),
):
    for tensor_path in tensor_files:
        print(tensor_path)
        video_name = os.path.basename(tensor_path).replace(suffix, "")
        loaded.append((video_name, torch.load(tensor_path)))

print("finished")

In [None]:
path = "/content/drive/MyDrive/model_training/swin_features/test"
swin_files = glob(os.path.join(path, "*.npy"))

# (video_name, feature tensor) pairs for the test split.
swin_features_test = []

for npy_path in swin_files:
    print(npy_path)
    video_name = os.path.basename(npy_path).replace(".npy", "")
    swin_features_test.append((video_name, get_swin_features(npy_path)))

print("finished")

In [None]:
# Standardise the test-split Swin features per feature dimension.
# NOTE(review): the mean/std are computed from the test split itself; confirm
# whether the training-split statistics should be reused here instead.
all_feats = torch.cat([feats for _, feats in swin_features_test], dim=0)

mean = all_feats.mean(dim=0)  # per-dimension mean over all test frames
std = all_feats.std(dim=0)    # per-dimension standard deviation
eps = 1e-6                    # keeps the division finite for constant dimensions

normalized_swin_features_test = [
    (vid, (feats - mean) / (std + eps))
    for vid, feats in swin_features_test
]

# Lookup by video name, as consumed by the dataset class.
normalized_swin_dict_test = dict(normalized_swin_features_test)

In [None]:
def time_to_float(time_str):
    """Convert an ``"H:M:S"`` time string to a total number of seconds.

    Each of the three colon-separated fields is parsed as a float, so the
    seconds field may carry a fractional part (e.g. ``"00:01:02.5"``).
    """
    hours, minutes, seconds = (float(field) for field in time_str.split(':'))
    return hours * 3600 + minutes * 60 + seconds

In [None]:
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer

class AllFeaturesDataset(Dataset):
    """In-memory dataset pairing each subtitle cue with its aligned video features.

    Every subtitle entry whose video id is present in all five feature
    dictionaries is turned into one sample at construction time. Frame-level
    features (keypoints, swin) are sliced to the cue's frame range; clip-level
    features (clip_gt, clip_ix, clip_preds) are expanded to one row per frame
    by mapping each frame index proportionally onto the clip axis.

    Args:
        keypoints_data: video id -> per-frame keypoint sequence (indexable by
            frame; each frame is a list of detected persons).
        subtitle_entries: list of dicts with ``"video_id"``, ``"start"``,
            ``"end"`` (seconds) and ``"text"``.
        clip_gt_data / clip_ix_data / clip_preds_data: video id -> tensor
            indexed by clip along dim 0.
            # assumes clip_gt / clip_ix are 1-D ([N_clip]) -- TODO confirm
        swin_data: video id -> tensor indexed by frame along dim 0.
        fps: Frame rate used to convert cue times into frame indices.
        tokenizer: Callable tokenizer; defaults to ``bert-base-uncased``.
        max_length: Fixed token length (truncate / pad).
        num_joints: Number of joints kept per frame (3 values per joint).
    """

    def __init__(
        self,
        keypoints_data: dict,
        subtitle_entries: list,
        clip_gt_data: dict,
        clip_ix_data: dict,
        clip_preds_data: dict,
        swin_data: dict,
        fps: int = 25,
        tokenizer=None,
        max_length: int = 80,
        num_joints: int = 25,
    ):
        self.fps = fps
        self.max_length = max_length
        self.num_joints = num_joints
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = tokenizer or AutoTokenizer.from_pretrained("bert-base-uncased")

        # Keep the per-video dictionaries for O(1) lookup while building samples.
        self.keypoints_data   = keypoints_data
        self.clip_gt_data     = clip_gt_data
        self.clip_ix_data     = clip_ix_data
        self.clip_preds_data  = clip_preds_data
        self.swin_data        = swin_data

        self.samples = self.build_samples(subtitle_entries)

    def _frame_to_vector(self, frame):
        """Flatten the first detected person of ``frame`` to ``num_joints * 3`` floats.

        Frames without a detection become an all-zero vector; longer keypoint
        lists are truncated and shorter ones are zero-padded.
        """
        if frame and len(frame[0]) > 0:
            person = frame[0]
            flat = [c for part in person for joint in part for c in joint]
        else:
            flat = []
        width = self.num_joints * 3
        flat = flat[:width] + [0] * max(0, width - len(flat))
        return torch.tensor(flat, dtype=torch.float32)

    def build_samples(self, subtitles):
        """Build the list of sample dicts, one per usable subtitle entry.

        An entry is skipped when its video is missing from any feature
        dictionary, when its frame range is empty, when the range runs past
        the end of either per-frame modality (keypoints or swin), or when the
        video has no clip-level features.
        """
        samples = []
        sources = (
            self.keypoints_data,
            self.clip_gt_data,
            self.clip_ix_data,
            self.clip_preds_data,
            self.swin_data,
        )

        for sub in subtitles:
            vid = sub["video_id"]
            if not all(vid in source for source in sources):
                continue

            kp_seq  = self.keypoints_data[vid]    # per-frame keypoints
            gt_feat = self.clip_gt_data[vid]      # indexed by clip
            ix_feat = self.clip_ix_data[vid]      # indexed by clip
            pr_feat = self.clip_preds_data[vid]   # indexed by clip
            sw_feat = self.swin_data[vid]         # indexed by frame

            N_frame = len(kp_seq)
            N_clip  = gt_feat.size(0)

            # Both per-frame modalities are indexed with the same frame range,
            # so the range must fit inside the shorter of the two.
            usable_frames = min(N_frame, len(sw_feat))

            start_f = int(sub["start"] * self.fps)
            end_f   = int(sub["end"]   * self.fps)
            if end_f <= start_f or end_f > usable_frames or N_clip == 0:
                continue

            frame_idxs = np.arange(start_f, end_f)

            # Map every frame proportionally onto the (coarser) clip axis.
            clip_idxs = np.floor(frame_idxs * (N_clip / N_frame)).astype(int)
            clip_idxs = np.clip(clip_idxs, 0, N_clip - 1)

            gt_tensor = gt_feat[clip_idxs].float().unsqueeze(-1).to(self.device)
            ix_tensor = ix_feat[clip_idxs].float().unsqueeze(-1).to(self.device)
            pr_tensor = pr_feat[clip_idxs].float().to(self.device)

            kp_tensor = torch.stack(
                [self._frame_to_vector(kp_seq[fi]) for fi in frame_idxs]
            ).to(self.device)

            sw_tensor = sw_feat[frame_idxs].to(self.device)

            tok = self.tokenizer(
                sub["text"],
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt"
            )["input_ids"].squeeze(0).to(self.device)

            samples.append({
                "keypoints":  kp_tensor,
                "clip_gt":    gt_tensor,
                "clip_ix":    ix_tensor,
                "clip_preds": pr_tensor,
                "swin":       sw_tensor,
                "tokens":     tok,
            })

        return samples

    def __len__(self):
        """Number of samples that were successfully built."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the pre-built sample dict at position ``idx``."""
        return self.samples[idx]

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Per-video lookups built from the (video_id, data) pairs loaded earlier.
keypoints_dict   = dict(keypoints)
clip_gt_dict     = dict(clip_gt_features)
clip_ix_dict     = dict(clip_ix_features)
clip_preds_dict  = dict(clip_preds_features)
swin_dict        = dict(swin_features)

# Standardise the training Swin features per feature dimension. This mirrors
# the test-split normalisation cell; without it `normalized_swin_dict` is
# undefined on a fresh kernel and the dataset construction below fails.
swin_train_all  = torch.cat([feats for _, feats in swin_features], dim=0)
swin_train_mean = swin_train_all.mean(dim=0)
swin_train_std  = swin_train_all.std(dim=0)
del swin_train_all  # the concatenated copy is only needed for the statistics

normalized_swin_dict = {
    vid: (feats - swin_train_mean) / (swin_train_std + 1e-6)
    for vid, feats in swin_features
}

all_features_train_dataset = AllFeaturesDataset(
    keypoints_data=keypoints_dict,
    subtitle_entries=subtitles,
    clip_gt_data=clip_gt_dict,
    clip_ix_data=clip_ix_dict,
    clip_preds_data=clip_preds_dict,
    swin_data=normalized_swin_dict,
    fps=25,
    tokenizer=tokenizer,
    max_length=80,
    num_joints=25,
)

print("Dataset size:", len(all_features_train_dataset))

# Sanity check: inspect the tensor shapes of the first training sample.
sample = all_features_train_dataset[0]
kp_tensor    = sample["keypoints"]   # [T x (25*3)]
clip_gt      = sample["clip_gt"]     # [T x D_gt]
clip_ix      = sample["clip_ix"]     # [T x D_ix]
clip_preds   = sample["clip_preds"]  # [T x D_pr]
swin_feats   = sample["swin"]        # [T x D_sw]
tokens       = sample["tokens"]      # [max_length]

print("Keypoints shape:   ", kp_tensor.shape)
print("CLIP GT shape:     ", clip_gt.shape)
print("CLIP IX shape:     ", clip_ix.shape)
print("CLIP preds shape:  ", clip_preds.shape)
print("Swin features shape:", swin_feats.shape)
print("Token IDs shape:   ", tokens.shape)

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Per-video lookups built from the (video_id, data) pairs of the test split.
keypoints_test_dict   = dict(keypoints_test)
clip_gt_test_dict     = dict(clip_gt_features_test)
clip_ix_test_dict     = dict(clip_ix_features_test)
clip_preds_test_dict  = dict(clip_preds_features_test)
swin_test_dict        = dict(swin_features_test)

all_features_test_dataset = AllFeaturesDataset(
    keypoints_data=keypoints_test_dict,
    subtitle_entries=subtitles_test,
    clip_gt_data=clip_gt_test_dict,
    clip_ix_data=clip_ix_test_dict,
    clip_preds_data=clip_preds_test_dict,
    swin_data=normalized_swin_dict_test,
    fps=25,
    tokenizer=tokenizer,
    max_length=80,
    num_joints=25,
)

print("Dataset size:", len(all_features_test_dataset))

# Sanity check on the *test* dataset (this previously inspected the training
# dataset by mistake, so the printed shapes said nothing about the test split).
sample = all_features_test_dataset[0]
kp_tensor    = sample["keypoints"]   # [T x (25*3)]
clip_gt      = sample["clip_gt"]     # [T x D_gt]
clip_ix      = sample["clip_ix"]     # [T x D_ix]
clip_preds   = sample["clip_preds"]  # [T x D_pr]
swin_feats   = sample["swin"]        # [T x D_sw]
tokens       = sample["tokens"]      # [max_length]

print("Keypoints shape:   ", kp_tensor.shape)
print("CLIP GT shape:     ", clip_gt.shape)
print("CLIP IX shape:     ", clip_ix.shape)
print("CLIP preds shape:  ", clip_preds.shape)
print("Swin features shape:", swin_feats.shape)
print("Token IDs shape:   ", tokens.shape)

In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence

def all_features_collate_fn(batch):
    """Collate variable-length samples into one zero-padded batch.

    Each of the five feature sequences is right-padded with 0.0 along the
    time axis to the longest sequence in the batch (batch-first layout). The
    token tensors are stacked as they are, so they must share one length.

    Returns:
        dict with keys ``"keypoints"``, ``"clip_gt"``, ``"clip_ix"``,
        ``"clip_preds"``, ``"swin"`` (each ``[B, T_max, ...]``) and
        ``"tokens"`` (``[B, L]``).
    """
    sequence_keys = ("keypoints", "clip_gt", "clip_ix", "clip_preds", "swin")

    collated = {
        key: pad_sequence(
            [item[key] for item in batch],
            batch_first=True,
            padding_value=0.0,
        )
        for key in sequence_keys
    }
    collated["tokens"] = torch.stack([item["tokens"] for item in batch], dim=0)
    return collated


In [None]:
# Training batches are reshuffled every epoch and padded by the custom collate function.
all_features_train_loader = DataLoader(
    dataset=all_features_train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=all_features_collate_fn,
)

In [None]:
# Evaluation data is iterated in a fixed order: shuffling does not change the
# aggregate loss/accuracy and only makes anything taken from "the first batch"
# differ from one epoch to the next.
all_features_test_loader = DataLoader(
    all_features_test_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=all_features_collate_fn,
)

In [None]:
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch-first sequence.

    Args:
        d_model: Size of the feature (last) dimension of the inputs.
        max_len: Longest sequence length that can be encoded.

    The table is stored as a non-persistent buffer: it follows the module
    across ``.to(device)`` calls but is not written to ``state_dict()``, so
    checkpoints stay interchangeable with ones saved when the table was a
    plain attribute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # [1, max_len, d_model]

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``T`` positions.

        Args:
            x: Tensor of shape ``[B, T, d_model]``.

        Raises:
            ValueError: If ``T`` exceeds the ``max_len`` given at construction.
        """
        T = x.size(1)
        if T > self.pe.size(1):
            raise ValueError(f"Sequence length {T} exceeds {self.pe.size(1)}")
        # `.to` is a no-op once the module lives on the same device as `x`.
        return x + self.pe[:, :T, :].to(x.device)

In [None]:
class AllFeaturesEncoderDecoderTransformer(nn.Module):
    """Multi-stream encoder-decoder transformer producing token logits.

    Five feature streams (keypoints, clip_gt, clip_ix, clip_preds, swin) are
    each projected to ``d_model``, position-encoded and run through their own
    transformer encoder. The keypoint stream then attends (``cross_layers``
    times, through one shared attention module) over the time-wise
    concatenation of the other four streams. The result is the memory of a
    transformer decoder over the target tokens.

    Args:
        keypoints_dim, clip_gt_dim, clip_ix_dim, clip_preds_dim, swin_dim:
            Feature size of each input stream.
        d_model: Shared model width.
        num_heads: Attention heads used everywhere.
        num_layers: Layers per stream encoder and in the decoder.
        cross_layers: Number of times the shared cross-attention is applied.
        ff_dim: Feed-forward width inside encoder/decoder layers.
        max_len: Longest source/target sequence the position table supports.
        vocab_size: Size of the token vocabulary.
        pad_idx: Token id treated as padding.
        dropout: Dropout probability.
    """

    def __init__(
        self,
        keypoints_dim: int = 75,
        clip_gt_dim:    int = 1,
        clip_ix_dim:    int = 1,
        clip_preds_dim: int = 1024,
        swin_dim:       int = 768,
        d_model:        int = 384,
        num_heads:      int = 6,
        num_layers:     int = 2,
        cross_layers:   int = 2,
        ff_dim:         int = 512,
        max_len:        int = 1024,
        vocab_size:     int = 30522,
        pad_idx:        int = 0,
        dropout:        float = 0.1,
    ):
        super().__init__()
        # One linear projection per input stream, all into the shared width.
        self.kp_proj = nn.Linear(keypoints_dim, d_model)
        self.gt_proj = nn.Linear(clip_gt_dim,    d_model)
        self.ix_proj = nn.Linear(clip_ix_dim,    d_model)
        self.pr_proj = nn.Linear(clip_preds_dim, d_model)
        self.sw_proj = nn.Linear(swin_dim,       d_model)

        self.pe       = PositionalEncoding(d_model, max_len)
        self.input_dp = nn.Dropout(dropout)

        # One encoder stack per stream (sequence-first layout).
        enc_layer   = nn.TransformerEncoderLayer(d_model, num_heads, ff_dim, dropout=dropout)
        self.kp_enc = nn.TransformerEncoder(enc_layer, num_layers, norm=nn.LayerNorm(d_model))
        self.gt_enc = nn.TransformerEncoder(enc_layer, num_layers, norm=nn.LayerNorm(d_model))
        self.ix_enc = nn.TransformerEncoder(enc_layer, num_layers, norm=nn.LayerNorm(d_model))
        self.pr_enc = nn.TransformerEncoder(enc_layer, num_layers, norm=nn.LayerNorm(d_model))
        self.sw_enc = nn.TransformerEncoder(enc_layer, num_layers, norm=nn.LayerNorm(d_model))

        # Single cross-attention block, re-applied `cross_layers` times in forward().
        self.cross_attn   = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.cross_ln     = nn.LayerNorm(d_model)
        self.cross_dp     = nn.Dropout(dropout)
        self.cross_layers = cross_layers

        dec_layer    = nn.TransformerDecoderLayer(d_model, num_heads, ff_dim, dropout=dropout)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers, norm=nn.LayerNorm(d_model))

        self.embed  = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.out_fc = nn.Linear(d_model, vocab_size)
        # Weight tying: the output projection shares the embedding matrix.
        # NOTE(review): with the default embedding initialisation the tied
        # projection may yield very large initial logits -- consider a scaled
        # initialisation; confirm against the observed starting loss.
        self.out_fc.weight = self.embed.weight

    def forward(
        self,
        keypoints,    # [B, T, keypoints_dim]
        clip_gt,      # [B, T, 1]
        clip_ix,      # [B, T, 1]
        clip_preds,   # [B, T, 1024]
        swin,         # [B, T, 768]
        input_ids,    # [B, L]
        tgt_mask=None # [L, L] causal mask
    ):
        """Compute next-token logits of shape ``[B, L, vocab_size]``.

        All five source tensors must share the same batch and time sizes.
        ``tgt_mask`` is passed to the decoder as its self-attention mask.
        NaN/inf logits are replaced by 0 / +-1e6 before returning.
        """
        # A time step is padding only when *every* stream is all-zero there.
        # Keypoints alone are not a reliable signal: a real frame in which no
        # person was detected also has all-zero keypoints, and must not be
        # masked out of the other streams.
        src_pad = (
            (keypoints.abs().sum(-1) == 0)
            & (clip_gt.abs().sum(-1) == 0)
            & (clip_ix.abs().sum(-1) == 0)
            & (clip_preds.abs().sum(-1) == 0)
            & (swin.abs().sum(-1) == 0)
        )  # [B, T], True = padded position

        # Project, add positions, apply dropout, then switch to [T, B, d_model].
        kp = self.input_dp(self.pe(self.kp_proj(keypoints))).permute(1, 0, 2)
        gt = self.input_dp(self.pe(self.gt_proj(clip_gt))).permute(1, 0, 2)
        ix = self.input_dp(self.pe(self.ix_proj(clip_ix))).permute(1, 0, 2)
        pr = self.input_dp(self.pe(self.pr_proj(clip_preds))).permute(1, 0, 2)
        sw = self.input_dp(self.pe(self.sw_proj(swin))).permute(1, 0, 2)

        kp_mem = self.kp_enc(kp, src_key_padding_mask=src_pad)
        gt_mem = self.gt_enc(gt, src_key_padding_mask=src_pad)
        ix_mem = self.ix_enc(ix, src_key_padding_mask=src_pad)
        pr_mem = self.pr_enc(pr, src_key_padding_mask=src_pad)
        sw_mem = self.sw_enc(sw, src_key_padding_mask=src_pad)

        # Keys/values: the four non-keypoint streams stacked along time -> [4T, B, d_model].
        kv     = torch.cat([gt_mem, ix_mem, pr_mem, sw_mem], dim=0)
        # Tile the mask in the same order as the concatenation -> [B, 4T].
        kv_pad = src_pad.repeat(1, 4)

        # Keypoint memory queries the other streams (residual + LayerNorm each pass).
        q = kp_mem
        for _ in range(self.cross_layers):
            attn_out, _ = self.cross_attn(
                query            = q,
                key              = kv,
                value            = kv,
                attn_mask        = None,
                key_padding_mask = kv_pad,
                need_weights     = False
            )
            q = self.cross_ln(q + self.cross_dp(attn_out))

        memory = q  # [T, B, d_model]

        tgt     = self.embed(input_ids)                       # [B, L, d_model]
        tgt     = self.input_dp(self.pe(tgt)).permute(1,0,2)  # -> [L, B, d_model]
        tgt_pad = (input_ids == self.embed.padding_idx)       # [B, L]

        out = self.decoder(
            tgt,                                # [L, B, d_model]
            memory,                             # [T, B, d_model]
            tgt_mask=tgt_mask,                  # [L, L] causal mask for self-attn
            memory_mask=None,                   # no 2D mask on cross-attn
            tgt_key_padding_mask=tgt_pad,       # [B, L]
            memory_key_padding_mask=src_pad,    # [B, T]
        )

        logits = self.out_fc(out).permute(1,0,2)  # -> [B, L, V]

        # Guard downstream loss/decoding against non-finite values.
        logits = torch.nan_to_num(logits, nan=0.0, posinf=1e6, neginf=-1e6)
        return logits

In [None]:
# Build the model with its default hyper-parameters, then move it to the selected device.
all_features_model = AllFeaturesEncoderDecoderTransformer()
all_features_model = all_features_model.to(device)

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm

def sample_decode_beam(
    model,
    keypoints,     # [T, 75]
    clip_gt,       # [T, 1]
    clip_ix,       # [T, 1]
    clip_preds,    # [T, 1024]
    swin,          # [T, 768]
    tokenizer,
    beam_width=3,
    max_len=80,
    eos_token_id=102,
):
    """Decode one (unbatched) example with beam search and return the text.

    Args:
        model: Callable as ``model(kp, gt, ix, pr, sw, input_ids, tgt_mask=...)``
            returning logits of shape ``[1, L, V]``; it is switched to eval
            mode and left there.
        keypoints, clip_gt, clip_ix, clip_preds, swin: Source features of a
            single example (no batch axis), all on the same device.
        tokenizer: Provides ``cls_token_id`` (start token) and ``decode``.
        beam_width: Number of hypotheses kept after every step.
        max_len: Maximum number of tokens generated after the start token.
        eos_token_id: Token id that terminates a hypothesis.

    Returns:
        The decoded string of the hypothesis with the lowest cumulative
        negative log-likelihood (special tokens skipped).
    """
    model.eval()
    device = keypoints.device

    # The source features never change while decoding: add the batch axis once.
    src = tuple(
        feats.unsqueeze(0)
        for feats in (keypoints, clip_gt, clip_ix, clip_preds, swin)
    )

    # Each hypothesis is (token ids, cumulative negative log-likelihood as a
    # plain float); lower scores are better.
    generated = [(torch.tensor([tokenizer.cls_token_id], device=device), 0.0)]

    with torch.no_grad():
        for _ in range(max_len):
            # Nothing left to extend once every beam has emitted EOS.
            if all(seq[-1].item() == eos_token_id for seq, _ in generated):
                break

            all_candidates = []
            for seq, score in generated:
                if seq[-1].item() == eos_token_id:
                    # Finished hypotheses are carried over unchanged.
                    all_candidates.append((seq, score))
                    continue

                input_ids = seq.unsqueeze(0)  # [1, L]
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(
                    input_ids.size(1)
                ).to(device)

                logits = model(*src, input_ids, tgt_mask=tgt_mask)  # [1, L, V]

                next_logits = logits[0, -1, :]       # [V]
                probs = torch.softmax(next_logits, dim=-1)
                # Never request more candidates than the vocabulary holds.
                k = min(beam_width, probs.size(-1))
                top_probs, top_idx = probs.topk(k)

                for p, idx in zip(top_probs, top_idx):
                    new_seq = torch.cat([seq, idx.unsqueeze(0)])
                    new_score = score - float(torch.log(p + 1e-12))
                    all_candidates.append((new_seq, new_score))

            generated = sorted(all_candidates, key=lambda cand: cand[1])[:beam_width]

    return tokenizer.decode(generated[0][0], skip_special_tokens=True)

In [None]:
def _token_accuracy_counts(logits, decoder_target, pad_token_id, eos_token_id):
    """Return ``(correct, total)`` over target positions that are neither PAD nor EOS."""
    preds = logits.argmax(-1)
    mask = (decoder_target != pad_token_id) & (decoder_target != eos_token_id)
    correct = ((preds == decoder_target) & mask).sum().item()
    return correct, mask.sum().item()


def _teacher_forced_step(
    model,
    batch,
    criterion,
    device,
    pad_token_id,
    eos_token_id,
    max_src_len=1024,
    max_tgt_len=80,
):
    """Run one teacher-forced forward pass on ``batch``.

    Returns ``(loss, correct, total)``, or ``None`` when the batch is skipped
    because its source or target length exceeds the given limits.
    """
    kp      = batch["keypoints"].to(device)    # [B, T, 75]
    gt      = batch["clip_gt"].to(device)      # [B, T, D_gt]
    ix      = batch["clip_ix"].to(device)      # [B, T, D_ix]
    pr      = batch["clip_preds"].to(device)   # [B, T, D_pr]
    sw      = batch["swin"].to(device)         # [B, T, D_sw]
    targets = batch["tokens"].to(device)       # [B, L]

    if kp.size(1) > max_src_len or targets.size(1) > max_tgt_len:
        return None

    # Teacher forcing: predict token t+1 from tokens <= t.
    decoder_input  = targets[:, :-1]
    decoder_target = targets[:, 1:]
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(
        decoder_input.size(1)
    ).to(device)

    logits = model(kp, gt, ix, pr, sw, decoder_input, tgt_mask=tgt_mask)
    loss = criterion(
        logits.reshape(-1, logits.size(-1)),
        decoder_target.reshape(-1)
    )
    correct, total = _token_accuracy_counts(
        logits, decoder_target, pad_token_id, eos_token_id
    )
    return loss, correct, total


def train_validate_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    scheduler,
    device,
    tokenizer,
    num_epochs=300,
    eos_token_id=102,
    pad_token_id=0,
):
    """Train ``model`` with per-epoch validation, early stopping and plots.

    Each epoch runs a teacher-forced training pass (gradient-norm clipping at
    1.0, scheduler stepped once per optimised batch) and a validation pass,
    then prints the metrics and one beam-search decoded validation sample.
    The weights with the lowest validation loss are saved to
    ``best_all_features__model.pt``. Training stops after 15 epochs without
    improvement. Accuracy/loss curves are plotted at the end.

    Train and validation accuracy use the same definition (PAD and EOS
    targets are ignored), and mean losses are taken over the batches that
    were actually processed, so skipped batches do not deflate them.

    Args:
        model: Network called as ``model(kp, gt, ix, pr, sw, input_ids, tgt_mask=...)``.
        train_loader / val_loader: Iterables of batch dicts with keys
            ``keypoints``, ``clip_gt``, ``clip_ix``, ``clip_preds``, ``swin``, ``tokens``.
        optimizer, criterion, scheduler: Optimiser, loss and LR scheduler.
        device: Device the batches are moved to.
        tokenizer: Used to decode the per-epoch sample.
        num_epochs: Maximum number of epochs.
        eos_token_id / pad_token_id: Target ids excluded from accuracy.

    Returns:
        None.
    """
    best_val_loss = float("inf")
    patience, patience_counter = 15, 0

    train_accuracies, val_accuracies = [], []
    train_losses, val_losses = [], []

    for epoch in range(num_epochs):
        # ---- TRAIN ----
        model.train()
        tot_loss = 0.0
        tot_correct = tot_tokens = train_batches = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]"):
            optimizer.zero_grad()
            step = _teacher_forced_step(
                model, batch, criterion, device, pad_token_id, eos_token_id
            )
            if step is None:
                continue
            loss, correct, total = step

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            tot_correct   += correct
            tot_tokens    += total
            tot_loss      += loss.item()
            train_batches += 1

        train_acc = tot_correct / tot_tokens * 100 if tot_tokens else 0.0
        train_loss = tot_loss / train_batches if train_batches else float("nan")
        train_accuracies.append(train_acc)
        train_losses.append(train_loss)

        # ---- VALIDATE ----
        model.eval()
        val_loss_sum = 0.0
        val_correct = val_tokens = val_batches = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating"):
                step = _teacher_forced_step(
                    model, batch, criterion, device, pad_token_id, eos_token_id
                )
                if step is None:
                    continue
                loss, correct, total = step

                val_correct  += correct
                val_tokens   += total
                val_loss_sum += loss.item()
                val_batches  += 1

        val_acc = val_correct / val_tokens * 100 if val_tokens else 0.0
        # A NaN validation loss never compares as an improvement below.
        val_loss = val_loss_sum / val_batches if val_batches else float("nan")
        val_accuracies.append(val_acc)
        val_losses.append(val_loss)

        print(
            f"Epoch {epoch+1} | "
            f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}% | "
            f"Val   Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%"
        )

        # Qualitative check: decode the first example of the first validation batch.
        first_val_batch = next(iter(val_loader), None)
        if first_val_batch is not None:
            sample_text = sample_decode_beam(
                model,
                first_val_batch["keypoints"][0].to(device),
                first_val_batch["clip_gt"][0].to(device),
                first_val_batch["clip_ix"][0].to(device),
                first_val_batch["clip_preds"][0].to(device),
                first_val_batch["swin"][0].to(device),
                tokenizer,
                beam_width=3,
                max_len=80,
                eos_token_id=tokenizer.sep_token_id
            )
            print("Sample:", sample_text)

        # ---- CHECKPOINT / EARLY STOPPING ----
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), "best_all_features__model.pt")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping.")
                break

    # Plotting
    import matplotlib.pyplot as plt
    epochs_range = range(1, len(train_losses) + 1)

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, train_accuracies, label='Train Accuracy')
    plt.plot(epochs_range, val_accuracies, label='Validation Accuracy')
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.title("Accuracy Over Epochs")
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, train_losses, label='Train Loss')
    plt.plot(epochs_range, val_losses, label='Validation Loss')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Loss Over Epochs")
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()

In [None]:
from transformers import get_cosine_schedule_with_warmup

# Token-level cross-entropy that ignores padded targets and smooths labels.
criterion = nn.CrossEntropyLoss(
    ignore_index=tokenizer.pad_token_id,
    label_smoothing=0.1,
)

optimizer = torch.optim.AdamW(
    params=all_features_model.parameters(),
    lr=1e-5,
    weight_decay=0.01,
)

# Cosine decay with a 100-step warm-up; the horizon is expressed in optimiser
# steps (batches per epoch x 500).
scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=len(all_features_train_loader) * 500,
)


In [None]:
import gc
# Run a garbage-collection pass first, then ask PyTorch to release the GPU
# memory it is caching but no longer using, before training starts.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Launch training; the test-split loader doubles as the validation loader.
train_validate_model(
    model=all_features_model,
    train_loader=all_features_train_loader,
    val_loader=all_features_test_loader,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    device=device,
    tokenizer=tokenizer,
    num_epochs=500,
)

Example Output:

Epoch 1 | Train Loss: 111.6891, Acc: 0.98% | Val   Loss: 47.6495, Acc: 2.07%
Sample: 
Epoch 2/500 [Train]: 100%|██████████| 143/143 [00:15<00:00,  9.29it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 25.24it/s]
Epoch 2 | Train Loss: 40.8139, Acc: 1.11% | Val   Loss: 32.9179, Acc: 2.56%
Sample: 
Epoch 3/500 [Train]: 100%|██████████| 143/143 [00:15<00:00,  9.22it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 26.13it/s]
Epoch 3 | Train Loss: 33.4913, Acc: 1.39% | Val   Loss: 28.9060, Acc: 3.62%
Sample: ,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
Epoch 4/500 [Train]: 100%|██████████| 143/143 [00:15<00:00,  9.28it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 25.12it/s]
Epoch 4 | Train Loss: 30.4782, Acc: 1.46% | Val   Loss: 27.2522, Acc: 7.35%
Sample: .