In [1]:
# STEP 1️⃣: Mount Drive and Install Libraries

from google.colab import drive
drive.mount('/content/drive')

# Install required libraries
!pip install -q nltk torchmetrics

# Download punkt tokenizer for BLEU (only once)
import nltk
nltk.download('punkt')


Mounted at /content/drive
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m961.5/961.5 kB[0m [31m25.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m66.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m36.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m49.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [2]:
# Step 2: Load Vocabulary and Dataset
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader

class VideoCaptionDataset(Dataset):
    def __init__(self, feature_dir, json_path, vocab, max_caption_len=45):
        self.feature_dir = feature_dir
        self.vocab = vocab
        self.max_len = max_caption_len

        with open(json_path, 'r') as f:
            data = json.load(f)

        self.video_ids = [v['video_id'] for v in data['videos']]
        self.captions_map = {}
        for s in data['sentences']:
            self.captions_map.setdefault(s['video_id'], []).append(s['caption'])

        # Filter only those video_ids that have feature file
        self.video_ids = [
            vid for vid in self.video_ids if os.path.exists(os.path.join(feature_dir, f"{vid}.npy"))
        ]

    def __len__(self):
        return len(self.video_ids)

    def __getitem__(self, idx):
        vid = self.video_ids[idx]
        feat_path = os.path.join(self.feature_dir, f"{vid}.npy")
        video_feat = torch.tensor(np.load(feat_path), dtype=torch.float32)

        caption = np.random.choice(self.captions_map[vid])
        tokens = [self.vocab['<SOS>']] + [
            self.vocab.get(w, self.vocab['<UNK>']) for w in caption.lower().split()
        ] + [self.vocab['<EOS>']]

        if len(tokens) < self.max_len:
            tokens += [self.vocab['<PAD>']] * (self.max_len - len(tokens))
        else:
            tokens = tokens[:self.max_len]

        caption_tensor = torch.tensor(tokens, dtype=torch.long)
        return video_feat, caption_tensor


In [3]:
#Load Vocab and Create Dataloaders
# Load vocab
vocab_path = '/content/drive/MyDrive/msvd_split/vocab.json'
with open(vocab_path, 'r') as f:
    vocab = json.load(f)

# Dataset paths
train_feature_dir = '/content/drive/MyDrive/msvd_split/train/features'
train_json = '/content/drive/MyDrive/msvd_split/train/train_captions.json'

val_feature_dir = '/content/drive/MyDrive/msvd_split/val/features'
val_json = '/content/drive/MyDrive/msvd_split/val/val_captions.json'

# Create datasets
train_dataset = VideoCaptionDataset(train_feature_dir, train_json, vocab)
val_dataset = VideoCaptionDataset(val_feature_dir, val_json, vocab)

# Create loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

print(f"✅ Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")


✅ Train size: 1251, Val size: 91


In [4]:
#Step 3: Define Transformer Decoder Model
# Full Code Block for the Model
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape: [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: [B, L, d_model]
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

# Transformer-based Decoder
class VideoTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_ff=2048,
                 max_len=45, dropout=0.1, input_feat_dim=2048):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.d_model = d_model

        # Embedding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, dropout, max_len)

        # Project video feature dim to model dim
        self.vid_fc = nn.Linear(input_feat_dim, d_model)

        # Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_ff, dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        # Output projection
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, video_feats, captions):
        """
        video_feats: [B, T, 2048]
        captions: [B, L] (tokenized)
        """
        B, T, _ = video_feats.shape
        L = captions.shape[1]

        # Encode video
        memory = self.vid_fc(video_feats)  # [B, T, d_model]

        # Embed target captions
        tgt_emb = self.embedding(captions) * math.sqrt(self.d_model)
        tgt_emb = self.pos_enc(tgt_emb)  # [B, L, d_model]

        # Generate subsequent mask for causal attention
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(L).to(captions.device)

        out = self.decoder(tgt=tgt_emb, memory=memory, tgt_mask=tgt_mask)
        logits = self.output(out)  # [B, L, vocab_size]
        return F.log_softmax(logits, dim=-1)


In [55]:
import os

# Create folder if not exists
folder = '/content/drive/MyDrive/transformer_model'
os.makedirs(folder, exist_ok=True)

# Save model file
path = os.path.join(folder, 'video_transformer.py')
with open(path, 'w') as f:
    f.write(model_code)

print(f"✅ Saved Transformer model class to: {path}")


✅ Saved Transformer model class to: /content/drive/MyDrive/transformer_model/video_transformer.py


In [58]:
model_code = '''import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class VideoTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_ff=2048,
                 max_len=45, dropout=0.1, input_feat_dim=2048):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.d_model = d_model

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, dropout, max_len)
        self.vid_fc = nn.Linear(input_feat_dim, d_model)

        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_ff, dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, video_feats, captions):
        B, T, _ = video_feats.shape
        L = captions.shape[1]

        memory = self.vid_fc(video_feats)
        tgt_emb = self.embedding(captions) * math.sqrt(self.d_model)
        tgt_emb = self.pos_enc(tgt_emb)

        tgt_mask = nn.Transformer.generate_square_subsequent_mask(L).to(captions.device)
        out = self.decoder(tgt=tgt_emb, memory=memory, tgt_mask=tgt_mask)
        logits = self.output(out)
        return F.log_softmax(logits, dim=-1)
'''

# 🔽 Save to Drive
path = '/content/drive/MyDrive/transformer_model/video_transformer.py'
with open(path, 'w') as f:
    f.write(model_code)

print(f"✅ Saved Transformer model class to: {path}")


✅ Saved Transformer model class to: /content/drive/MyDrive/transformer_model/video_transformer.py


In [59]:
import sys
sys.path.append('/content/drive/MyDrive/transformer_model')
from video_transformer import VideoTransformer


In [5]:
#Sample Instantiation
model = VideoTransformer(
    vocab_size=len(vocab),
    d_model=512,
    nhead=8,
    num_layers=3,
    dim_ff=2048,
    max_len=45,
    dropout=0.1,
    input_feat_dim=2048  # Since we're using ResNet/I3D features
).to("cuda" if torch.cuda.is_available() else "cpu")


In [6]:
#Step 4: Training Loop with Label Smoothing
# 1. Label Smoothing Loss Function
def label_smoothing_loss(pred, target, vocab_size, smoothing=0.1, pad_idx=0):
    """
    pred: [B*L, V]
    target: [B*L]
    """
    confidence = 1.0 - smoothing
    true_dist = torch.zeros_like(pred)
    true_dist.fill_(smoothing / (vocab_size - 2))
    true_dist.scatter_(1, target.unsqueeze(1), confidence)
    true_dist.masked_fill_((target == pad_idx).unsqueeze(1), 0)
    pred = F.log_softmax(pred, dim=1)
    loss = -torch.sum(true_dist * pred, dim=1)
    mask = (target != pad_idx).float()
    return torch.sum(loss * mask) / torch.sum(mask)


In [7]:
# 2. Full Training Loop
import torch.optim as optim
from tqdm import tqdm
import os

def train_transformer(model, train_loader, val_loader, vocab,
                      device, num_epochs=20, learning_rate=1e-4,
                      checkpoint_dir='/content/drive/MyDrive/transformer_checkpoints'):

    os.makedirs(checkpoint_dir, exist_ok=True)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    pad_idx = vocab['<PAD>']

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0
        print(f"\n🔁 Epoch {epoch}/{num_epochs}")

        for video_feats, captions in tqdm(train_loader, desc="Training"):
            video_feats, captions = video_feats.to(device), captions.to(device)

            inputs = captions[:, :-1]
            targets = captions[:, 1:]

            optimizer.zero_grad()
            output = model(video_feats, inputs)  # [B, L, V]
            loss = label_smoothing_loss(
                output.view(-1, output.size(-1)),
                targets.reshape(-1),
                vocab_size=len(vocab),
                smoothing=0.1,
                pad_idx=pad_idx
            )
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"✅ Avg Training Loss: {avg_train_loss:.4f}")

        # === Validation ===
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for video_feats, captions in tqdm(val_loader, desc="Validating"):
                video_feats, captions = video_feats.to(device), captions.to(device)
                inputs = captions[:, :-1]
                targets = captions[:, 1:]
                output = model(video_feats, inputs)
                loss = label_smoothing_loss(
                    output.view(-1, output.size(-1)),
                    targets.reshape(-1),
                    vocab_size=len(vocab),
                    smoothing=0.1,
                    pad_idx=pad_idx
                )
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"🧪 Avg Validation Loss: {avg_val_loss:.4f}")
        scheduler.step()

        # Save checkpoint
        ckpt_path = f"{checkpoint_dir}/transformer_epoch_{epoch}.pt"
        torch.save(model.state_dict(), ckpt_path)
        print(f"💾 Saved model to {ckpt_path}")


In [9]:
# Start Training
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_transformer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    vocab=vocab,
    device=device,
    num_epochs=20,
    learning_rate=1e-4,
    checkpoint_dir='/content/drive/MyDrive/transformer_checkpoints'
)



🔁 Epoch 1/20


Training: 100%|██████████| 157/157 [09:19<00:00,  3.56s/it]


✅ Avg Training Loss: 5.9045


Validating: 100%|██████████| 12/12 [00:39<00:00,  3.26s/it]


🧪 Avg Validation Loss: 5.4615
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_1.pt

🔁 Epoch 2/20


Training: 100%|██████████| 157/157 [00:08<00:00, 18.13it/s]


✅ Avg Training Loss: 4.9812


Validating: 100%|██████████| 12/12 [00:00<00:00, 31.84it/s]


🧪 Avg Validation Loss: 4.8362
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_2.pt

🔁 Epoch 3/20


Training: 100%|██████████| 157/157 [00:08<00:00, 18.23it/s]


✅ Avg Training Loss: 4.7322


Validating: 100%|██████████| 12/12 [00:00<00:00, 32.74it/s]


🧪 Avg Validation Loss: 4.6030
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_3.pt

🔁 Epoch 4/20


Training: 100%|██████████| 157/157 [00:07<00:00, 20.06it/s]


✅ Avg Training Loss: 4.5940


Validating: 100%|██████████| 12/12 [00:00<00:00, 34.31it/s]


🧪 Avg Validation Loss: 4.5518
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_4.pt

🔁 Epoch 5/20


Training: 100%|██████████| 157/157 [00:08<00:00, 17.45it/s]


✅ Avg Training Loss: 4.4939


Validating: 100%|██████████| 12/12 [00:00<00:00, 33.84it/s]


🧪 Avg Validation Loss: 4.2958
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_5.pt

🔁 Epoch 6/20


Training: 100%|██████████| 157/157 [00:08<00:00, 19.46it/s]


✅ Avg Training Loss: 4.3178


Validating: 100%|██████████| 12/12 [00:00<00:00, 34.17it/s]


🧪 Avg Validation Loss: 4.2612
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_6.pt

🔁 Epoch 7/20


Training: 100%|██████████| 157/157 [00:08<00:00, 18.01it/s]


✅ Avg Training Loss: 4.3280


Validating: 100%|██████████| 12/12 [00:00<00:00, 31.59it/s]


🧪 Avg Validation Loss: 4.3848
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_7.pt

🔁 Epoch 8/20


Training: 100%|██████████| 157/157 [00:08<00:00, 19.24it/s]


✅ Avg Training Loss: 4.2571


Validating: 100%|██████████| 12/12 [00:00<00:00, 32.99it/s]


🧪 Avg Validation Loss: 4.4411
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_8.pt

🔁 Epoch 9/20


Training: 100%|██████████| 157/157 [00:08<00:00, 18.09it/s]


✅ Avg Training Loss: 4.2404


Validating: 100%|██████████| 12/12 [00:00<00:00, 32.74it/s]


🧪 Avg Validation Loss: 4.3478
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_9.pt

🔁 Epoch 10/20


Training: 100%|██████████| 157/157 [00:07<00:00, 20.09it/s]


✅ Avg Training Loss: 4.1585


Validating: 100%|██████████| 12/12 [00:00<00:00, 27.16it/s]


🧪 Avg Validation Loss: 4.2052
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_10.pt

🔁 Epoch 11/20


Training: 100%|██████████| 157/157 [00:08<00:00, 17.59it/s]


✅ Avg Training Loss: 4.1074


Validating: 100%|██████████| 12/12 [00:00<00:00, 34.68it/s]


🧪 Avg Validation Loss: 4.1991
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_11.pt

🔁 Epoch 12/20


Training: 100%|██████████| 157/157 [00:08<00:00, 18.15it/s]


✅ Avg Training Loss: 4.1664


Validating: 100%|██████████| 12/12 [00:00<00:00, 34.14it/s]


🧪 Avg Validation Loss: 4.3796
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_12.pt

🔁 Epoch 13/20


Training: 100%|██████████| 157/157 [00:08<00:00, 18.01it/s]


✅ Avg Training Loss: 4.1522


Validating: 100%|██████████| 12/12 [00:00<00:00, 33.83it/s]


🧪 Avg Validation Loss: 4.2760
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_13.pt

🔁 Epoch 14/20


Training: 100%|██████████| 157/157 [00:07<00:00, 19.92it/s]


✅ Avg Training Loss: 4.0673


Validating: 100%|██████████| 12/12 [00:00<00:00, 34.35it/s]


🧪 Avg Validation Loss: 4.4561
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_14.pt

🔁 Epoch 15/20


Training: 100%|██████████| 157/157 [00:08<00:00, 17.65it/s]


✅ Avg Training Loss: 4.0814


Validating: 100%|██████████| 12/12 [00:00<00:00, 34.44it/s]


🧪 Avg Validation Loss: 4.1520
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_15.pt

🔁 Epoch 16/20


Training: 100%|██████████| 157/157 [00:08<00:00, 19.11it/s]


✅ Avg Training Loss: 4.0723


Validating: 100%|██████████| 12/12 [00:00<00:00, 34.64it/s]


🧪 Avg Validation Loss: 4.1365
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_16.pt

🔁 Epoch 17/20


Training: 100%|██████████| 157/157 [00:08<00:00, 18.35it/s]


✅ Avg Training Loss: 4.1023


Validating: 100%|██████████| 12/12 [00:00<00:00, 28.99it/s]


🧪 Avg Validation Loss: 4.2014
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_17.pt

🔁 Epoch 18/20


Training: 100%|██████████| 157/157 [00:08<00:00, 18.84it/s]


✅ Avg Training Loss: 4.1038


Validating: 100%|██████████| 12/12 [00:00<00:00, 31.03it/s]


🧪 Avg Validation Loss: 4.2107
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_18.pt

🔁 Epoch 19/20


Training: 100%|██████████| 157/157 [00:08<00:00, 17.64it/s]


✅ Avg Training Loss: 3.9779


Validating: 100%|██████████| 12/12 [00:00<00:00, 31.93it/s]


🧪 Avg Validation Loss: 4.2921
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_19.pt

🔁 Epoch 20/20


Training: 100%|██████████| 157/157 [00:08<00:00, 18.89it/s]


✅ Avg Training Loss: 3.9967


Validating: 100%|██████████| 12/12 [00:00<00:00, 28.36it/s]


🧪 Avg Validation Loss: 4.0402
💾 Saved model to /content/drive/MyDrive/transformer_checkpoints/transformer_epoch_20.pt


In [51]:
# ✅ Replace XX with final epoch number if needed
save_path = "/content/drive/MyDrive/transformer_checkpoints/transformer_final.pt"

torch.save(model.state_dict(), save_path)
print(f"✅ Model saved at {save_path}")


✅ Model saved at /content/drive/MyDrive/transformer_checkpoints/transformer_final.pt


In [10]:
#Step 5: Inference with Greedy Decoding
#Greedy Decoding Function

def greedy_decode(model, video_feat, vocab, max_len=45, device='cuda'):
    """
    video_feat: [1, T, 2048] - single video
    """
    model.eval()
    sos_id = vocab['<SOS>']
    eos_id = vocab['<EOS>']
    inv_vocab = {v: k for k, v in vocab.items()}

    caption = [sos_id]
    video_feat = video_feat.to(device)

    with torch.no_grad():
        for _ in range(max_len):
            input_tensor = torch.tensor(caption).unsqueeze(0).to(device)  # [1, L]
            output = model(video_feat, input_tensor)  # [1, L, V]
            next_word = output[0, -1].argmax(dim=-1).item()
            caption.append(next_word)
            if next_word == eos_id:
                break

    decoded = [inv_vocab.get(tok, '<UNK>') for tok in caption[1:] if tok != eos_id]
    return ' '.join(decoded)


In [12]:
video_path = '/content/drive/MyDrive/msvd_split/test/features/_SNE2MYAotU_41_49.npy'
video_feat = torch.tensor(np.load(video_path)).unsqueeze(0).float()

# Generate caption
caption = greedy_decode(model, video_feat, vocab, max_len=45, device=device)
print("🎬 Generated Caption:", caption)


🎬 Generated Caption: a monkey is eating a small animal


In [13]:
import os
import numpy as np
import torch

feature_dir = '/content/drive/MyDrive/msvd_split/test/features'
filename = np.random.choice([f for f in os.listdir(feature_dir) if f.endswith('.npy')])
video_path = os.path.join(feature_dir, filename)

video_feat = torch.tensor(np.load(video_path)).unsqueeze(0).float().to(device)
caption = greedy_decode(model, video_feat, vocab, max_len=45, device=device)

print(f"🎬 File: {filename}")
print("📝 Caption:", caption)


🎬 File: NV6pq1W-I4g_7_16.npy
📝 Caption: a woman is putting a piece of a piece of a woman


In [14]:
#Step 6: BLEU Score Evaluation
#BLEU Evaluation Code
import os
import json
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def evaluate_bleu(model, feature_dir, caption_json, vocab, max_len=45, device='cuda'):
    with open(caption_json, 'r') as f:
        data = json.load(f)

    references = {}
    for entry in data['sentences']:
        vid = entry['video_id']
        caption = entry['caption'].lower().split()
        references.setdefault(vid, []).append(caption)

    inv_vocab = {v: k for k, v in vocab.items()}
    scores = []
    smooth_fn = SmoothingFunction().method4

    model.eval()
    for vid in tqdm(references.keys(), desc="Evaluating BLEU"):
        feat_path = os.path.join(feature_dir, f"{vid}.npy")
        if not os.path.exists(feat_path):
            continue

        video_feat = torch.tensor(np.load(feat_path)).unsqueeze(0).float().to(device)
        pred = greedy_decode(model, video_feat, vocab, max_len, device)
        pred_tokens = pred.lower().split()

        bleu = sentence_bleu(references[vid], pred_tokens, smoothing_function=smooth_fn)
        scores.append(bleu)

    avg_bleu = sum(scores) / len(scores)
    print(f"\n📊 Average BLEU-4 Score: {avg_bleu:.4f}")
    return avg_bleu


In [15]:
# Run BLEU Evaluation
feature_dir_test = '/content/drive/MyDrive/msvd_split/test/features'
json_test = '/content/drive/MyDrive/msvd_split/test/test_captions.json'

evaluate_bleu(
    model=model,
    feature_dir=feature_dir_test,
    caption_json=json_test,
    vocab=vocab,
    max_len=45,
    device=device
)


Evaluating BLEU: 100%|██████████| 78/78 [00:13<00:00,  5.90it/s]


📊 Average BLEU-4 Score: 0.4351





0.4350781727165359

In [16]:
# 1. Save Predictions to JSON
def save_predictions_to_json(model, feature_dir, output_path, vocab, max_len=45, device='cuda'):
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    inv_vocab = {v: k for k, v in vocab.items()}
    predictions = {}

    model.eval()
    for filename in tqdm(os.listdir(feature_dir), desc="Saving Predictions"):
        if not filename.endswith('.npy'):
            continue
        vid = os.path.splitext(filename)[0]
        feat_path = os.path.join(feature_dir, filename)

        video_feat = torch.tensor(np.load(feat_path)).unsqueeze(0).float().to(device)
        pred = greedy_decode(model, video_feat, vocab, max_len, device)
        predictions[vid] = pred

    with open(output_path, 'w') as f:
        json.dump(predictions, f, indent=4)
    print(f"✅ Saved predictions to {output_path}")


In [17]:
# Usage
save_predictions_to_json(
    model=model,
    feature_dir='/content/drive/MyDrive/msvd_split/test/features',
    output_path='/content/drive/MyDrive/msvd_split/predictions_transformer.json',
    vocab=vocab,
    max_len=45,
    device=device
)


Saving Predictions: 100%|██████████| 548/548 [00:16<00:00, 32.91it/s]

✅ Saved predictions to /content/drive/MyDrive/msvd_split/predictions_transformer.json





In [28]:
#MINTOR
!pip install nltk
import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')


[33mDEPRECATION: Loading egg at /usr/local/lib/python3.11/dist-packages/pycocoevalcap-1.2-py3.11.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


True

In [29]:
# Reuse from previous step
refs = {
    "video123": ["a man is dancing", "a person is performing dance"],
    "video456": ["a dog is running", "an animal is sprinting"]
}

preds = {
    "video123": ["a man dancing"],
    "video456": ["a dog is running"]
}


In [30]:
from nltk.translate.meteor_score import meteor_score

def compute_meteor_score(refs, preds):
    scores = []
    for vid in preds:
        pred = preds[vid][0]
        references = refs.get(vid, [])
        if references:
            score = meteor_score(references, pred)
            scores.append(score)
    avg_score = sum(scores) / len(scores)
    print(f"🌟 METEOR Score: {avg_score:.4f}")
    return avg_score


In [34]:
import json

# Load ground truth
with open('/content/drive/MyDrive/msvd_split/test/test_captions.json', 'r') as f:
    gt = json.load(f)

refs = {}
for s in gt['sentences']:
    refs.setdefault(s['video_id'], []).append(s['caption'].lower())

# Load predictions
with open('/content/drive/MyDrive/msvd_split/predictions_transformer.json', 'r') as f:
    pred_json = json.load(f)

preds = {k: [v.lower()] for k, v in pred_json.items() if k in refs}

# Evaluate METEOR
compute_meteor_score(refs, preds)


🌟 METEOR Score: 0.6517


0.6516606132281912

In [32]:
from nltk.translate.meteor_score import meteor_score

def compute_meteor_score(refs, preds):
    scores = []
    for vid in preds:
        pred = preds[vid][0].split()  # 🔁 tokenized prediction
        references = [ref.split() for ref in refs.get(vid, [])]  # 🔁 tokenized references
        if references:
            score = meteor_score(references, pred)
            scores.append(score)
    avg_score = sum(scores) / len(scores)
    print(f"🌟 METEOR Score: {avg_score:.4f}")
    return avg_score


In [33]:
preds = {k: [v.lower()] for k, v in pred_json.items() if k in refs}


In [45]:
#give mp4 video and predict
video_path = '/content/04.mp4'  # 👈 replace with actual path


In [46]:
#Step 2: Extract frames using ffmpeg
import os
import subprocess

def extract_frames_from_mp4(video_path, out_dir, num_frames=40):
    os.makedirs(out_dir, exist_ok=True)
    # Extract 40 frames equally spaced
    cmd = f"ffmpeg -i {video_path} -vf fps=1 {out_dir}/%06d.jpg -hide_banner -loglevel error"
    subprocess.call(cmd, shell=True)

# Example usage
extract_frames_from_mp4(video_path, out_dir="/content/frames")


In [48]:
!pip install pretrainedmodels


[33mDEPRECATION: Loading egg at /usr/local/lib/python3.11/dist-packages/pycocoevalcap-1.2-py3.11.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0mCollecting pretrainedmodels
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting munch (from pretrainedmodels)
  Downloading munch-4.0.0-py2.py3-none-any.whl.metadata (5.9 kB)
Downloading munch-4.0.0-py2.py3-none-any.whl (9.9 kB)
Building wheels for collected packages: pretrainedmodels
  Building wheel for pretrainedmodels (setup.py) ... [?25l[?25hdone
  Created wheel for pretrainedmodels: filename=pretrainedmodels-0.7.4-py3-none-any.whl size=60945 sha256=aec4ccc27e149f2af5fb7cf517452c5fa0d23a920e2

In [49]:
#Step 3: Extract ResNet/CLIP Features
import torch
import numpy as np
import glob
from pretrainedmodels import resnet152
from pretrainedmodels.utils import LoadTransformImage

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet152(pretrained='imagenet')
model.last_linear = torch.nn.Identity()
model = model.to(device).eval()
load_img = LoadTransformImage(model)

def extract_features_from_frames(frame_dir, num_frames=40):
    images = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')))
    sampled = np.linspace(0, len(images) - 1, num=num_frames).astype(int)
    images = [images[i] for i in sampled]

    feats = []
    for img_path in images:
        img = load_img(img_path).unsqueeze(0).to(device)
        with torch.no_grad():
            feat = model(img).cpu().squeeze()
        feats.append(feat)

    return torch.stack(feats).unsqueeze(0)  # [1, 40, 2048]

# Example
video_feat = extract_features_from_frames("/content/frames").float().to(device)


Downloading: "https://download.pytorch.org/models/resnet152-b121ed2d.pth" to /root/.cache/torch/hub/checkpoints/resnet152-b121ed2d.pth
100%|██████████| 230M/230M [00:02<00:00, 94.5MB/s]


In [62]:
import sys
sys.path.append('/content/drive/MyDrive/transformer_model')

from video_transformer import VideoTransformer


In [63]:
# Load vocab beforehand
model = VideoTransformer(
    vocab_size=len(vocab),
    d_model=512,
    nhead=8,
    num_layers=3,
    dim_ff=2048,
    max_len=45,
    dropout=0.1,
    input_feat_dim=2048
).to(device)

# Load weights from final checkpoint
model.load_state_dict(torch.load('/content/drive/MyDrive/transformer_checkpoints/transformer_final.pt'))
model.eval()


RuntimeError: Error(s) in loading state_dict for VideoTransformer:
	Missing key(s) in state_dict: "embedding.weight", "pos_enc.pe", "vid_fc.weight", "vid_fc.bias", "decoder.layers.0.self_attn.in_proj_weight", "decoder.layers.0.self_attn.in_proj_bias", "decoder.layers.0.self_attn.out_proj.weight", "decoder.layers.0.self_attn.out_proj.bias", "decoder.layers.0.multihead_attn.in_proj_weight", "decoder.layers.0.multihead_attn.in_proj_bias", "decoder.layers.0.multihead_attn.out_proj.weight", "decoder.layers.0.multihead_attn.out_proj.bias", "decoder.layers.0.linear1.weight", "decoder.layers.0.linear1.bias", "decoder.layers.0.linear2.weight", "decoder.layers.0.linear2.bias", "decoder.layers.0.norm1.weight", "decoder.layers.0.norm1.bias", "decoder.layers.0.norm2.weight", "decoder.layers.0.norm2.bias", "decoder.layers.0.norm3.weight", "decoder.layers.0.norm3.bias", "decoder.layers.1.self_attn.in_proj_weight", "decoder.layers.1.self_attn.in_proj_bias", "decoder.layers.1.self_attn.out_proj.weight", "decoder.layers.1.self_attn.out_proj.bias", "decoder.layers.1.multihead_attn.in_proj_weight", "decoder.layers.1.multihead_attn.in_proj_bias", "decoder.layers.1.multihead_attn.out_proj.weight", "decoder.layers.1.multihead_attn.out_proj.bias", "decoder.layers.1.linear1.weight", "decoder.layers.1.linear1.bias", "decoder.layers.1.linear2.weight", "decoder.layers.1.linear2.bias", "decoder.layers.1.norm1.weight", "decoder.layers.1.norm1.bias", "decoder.layers.1.norm2.weight", "decoder.layers.1.norm2.bias", "decoder.layers.1.norm3.weight", "decoder.layers.1.norm3.bias", "decoder.layers.2.self_attn.in_proj_weight", "decoder.layers.2.self_attn.in_proj_bias", "decoder.layers.2.self_attn.out_proj.weight", "decoder.layers.2.self_attn.out_proj.bias", "decoder.layers.2.multihead_attn.in_proj_weight", "decoder.layers.2.multihead_attn.in_proj_bias", "decoder.layers.2.multihead_attn.out_proj.weight", "decoder.layers.2.multihead_attn.out_proj.bias", "decoder.layers.2.linear1.weight", "decoder.layers.2.linear1.bias", "decoder.layers.2.linear2.weight", "decoder.layers.2.linear2.bias", "decoder.layers.2.norm1.weight", "decoder.layers.2.norm1.bias", "decoder.layers.2.norm2.weight", "decoder.layers.2.norm2.bias", "decoder.layers.2.norm3.weight", "decoder.layers.2.norm3.bias", "output.weight", "output.bias". 
	Unexpected key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "bn1.num_batches_tracked", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn1.num_batches_tracked", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.bn2.num_batches_tracked", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.bn3.num_batches_tracked", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.0.downsample.1.num_batches_tracked", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.bn1.num_batches_tracked", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.bn2.num_batches_tracked", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.1.bn3.num_batches_tracked", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.bn1.num_batches_tracked", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.bn2.num_batches_tracked", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer1.2.bn3.num_batches_tracked", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.bn1.num_batches_tracked", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.bn2.num_batches_tracked", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.0.bn3.num_batches_tracked", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.0.downsample.1.num_batches_tracked", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.bn1.num_batches_tracked", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.1.bn2.num_batches_tracked", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.1.bn3.num_batches_tracked", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.bn1.num_batches_tracked", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.bn2.num_batches_tracked", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.2.bn3.num_batches_tracked", "layer2.3.conv1.weight", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.bn1.num_batches_tracked", "layer2.3.conv2.weight", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer2.3.bn2.num_batches_tracked", "layer2.3.conv3.weight", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer2.3.bn3.num_batches_tracked", "layer2.4.conv1.weight", "layer2.4.bn1.weight", "layer2.4.bn1.bias", "layer2.4.bn1.running_mean", "layer2.4.bn1.running_var", "layer2.4.bn1.num_batches_tracked", "layer2.4.conv2.weight", "layer2.4.bn2.weight", "layer2.4.bn2.bias", "layer2.4.bn2.running_mean", "layer2.4.bn2.running_var", "layer2.4.bn2.num_batches_tracked", "layer2.4.conv3.weight", "layer2.4.bn3.weight", "layer2.4.bn3.bias", "layer2.4.bn3.running_mean", "layer2.4.bn3.running_var", "layer2.4.bn3.num_batches_tracked", "layer2.5.conv1.weight", "layer2.5.bn1.weight", "layer2.5.bn1.bias", "layer2.5.bn1.running_mean", "layer2.5.bn1.running_var", "layer2.5.bn1.num_batches_tracked", "layer2.5.conv2.weight", "layer2.5.bn2.weight", "layer2.5.bn2.bias", "layer2.5.bn2.running_mean", "layer2.5.bn2.running_var", "layer2.5.bn2.num_batches_tracked", "layer2.5.conv3.weight", "layer2.5.bn3.weight", "layer2.5.bn3.bias", "layer2.5.bn3.running_mean", "layer2.5.bn3.running_var", "layer2.5.bn3.num_batches_tracked", "layer2.6.conv1.weight", "layer2.6.bn1.weight", "layer2.6.bn1.bias", "layer2.6.bn1.running_mean", "layer2.6.bn1.running_var", "layer2.6.bn1.num_batches_tracked", "layer2.6.conv2.weight", "layer2.6.bn2.weight", "layer2.6.bn2.bias", "layer2.6.bn2.running_mean", "layer2.6.bn2.running_var", "layer2.6.bn2.num_batches_tracked", "layer2.6.conv3.weight", "layer2.6.bn3.weight", "layer2.6.bn3.bias", "layer2.6.bn3.running_mean", "layer2.6.bn3.running_var", "layer2.6.bn3.num_batches_tracked", "layer2.7.conv1.weight", "layer2.7.bn1.weight", "layer2.7.bn1.bias", "layer2.7.bn1.running_mean", "layer2.7.bn1.running_var", "layer2.7.bn1.num_batches_tracked", "layer2.7.conv2.weight", "layer2.7.bn2.weight", "layer2.7.bn2.bias", "layer2.7.bn2.running_mean", "layer2.7.bn2.running_var", "layer2.7.bn2.num_batches_tracked", "layer2.7.conv3.weight", "layer2.7.bn3.weight", "layer2.7.bn3.bias", "layer2.7.bn3.running_mean", "layer2.7.bn3.running_var", "layer2.7.bn3.num_batches_tracked", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.bn1.num_batches_tracked", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.bn2.num_batches_tracked", "layer3.0.conv3.weight", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.0.bn3.num_batches_tracked", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.0.downsample.1.num_batches_tracked", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.bn1.num_batches_tracked", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.1.bn2.num_batches_tracked", "layer3.1.conv3.weight", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.1.bn3.num_batches_tracked", "layer3.2.conv1.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.bn1.num_batches_tracked", "layer3.2.conv2.weight", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.2.bn2.num_batches_tracked", "layer3.2.conv3.weight", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.2.bn3.num_batches_tracked", "layer3.3.conv1.weight", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.bn1.num_batches_tracked", "layer3.3.conv2.weight", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.3.bn2.num_batches_tracked", "layer3.3.conv3.weight", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.3.bn3.num_batches_tracked", "layer3.4.conv1.weight", "layer3.4.bn1.weight", "layer3.4.bn1.bias", "layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.bn1.num_batches_tracked", "layer3.4.conv2.weight", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.4.bn2.num_batches_tracked", "layer3.4.conv3.weight", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.4.bn3.num_batches_tracked", "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.bn1.num_batches_tracked", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.bn2.num_batches_tracked", "layer3.5.conv3.weight", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer3.5.bn3.num_batches_tracked", "layer3.6.conv1.weight", "layer3.6.bn1.weight", "layer3.6.bn1.bias", "layer3.6.bn1.running_mean", "layer3.6.bn1.running_var", "layer3.6.bn1.num_batches_tracked", "layer3.6.conv2.weight", "layer3.6.bn2.weight", "layer3.6.bn2.bias", "layer3.6.bn2.running_mean", "layer3.6.bn2.running_var", "layer3.6.bn2.num_batches_tracked", "layer3.6.conv3.weight", "layer3.6.bn3.weight", "layer3.6.bn3.bias", "layer3.6.bn3.running_mean", "layer3.6.bn3.running_var", "layer3.6.bn3.num_batches_tracked", "layer3.7.conv1.weight", "layer3.7.bn1.weight", "layer3.7.bn1.bias", "layer3.7.bn1.running_mean", "layer3.7.bn1.running_var", "layer3.7.bn1.num_batches_tracked", "layer3.7.conv2.weight", "layer3.7.bn2.weight", "layer3.7.bn2.bias", "layer3.7.bn2.running_mean", "layer3.7.bn2.running_var", "layer3.7.bn2.num_batches_tracked", "layer3.7.conv3.weight", "layer3.7.bn3.weight", "layer3.7.bn3.bias", "layer3.7.bn3.running_mean", "layer3.7.bn3.running_var", "layer3.7.bn3.num_batches_tracked", "layer3.8.conv1.weight", "layer3.8.bn1.weight", "layer3.8.bn1.bias", "layer3.8.bn1.running_mean", "layer3.8.bn1.running_var", "layer3.8.bn1.num_batches_tracked", "layer3.8.conv2.weight", "layer3.8.bn2.weight", "layer3.8.bn2.bias", "layer3.8.bn2.running_mean", "layer3.8.bn2.running_var", "layer3.8.bn2.num_batches_tracked", "layer3.8.conv3.weight", "layer3.8.bn3.weight", "layer3.8.bn3.bias", "layer3.8.bn3.running_mean", "layer3.8.bn3.running_var", "layer3.8.bn3.num_batches_tracked", "layer3.9.conv1.weight", "layer3.9.bn1.weight", "layer3.9.bn1.bias", "layer3.9.bn1.running_mean", "layer3.9.bn1.running_var", "layer3.9.bn1.num_batches_tracked", "layer3.9.conv2.weight", "layer3.9.bn2.weight", "layer3.9.bn2.bias", "layer3.9.bn2.running_mean", "layer3.9.bn2.running_var", "layer3.9.bn2.num_batches_tracked", "layer3.9.conv3.weight", "layer3.9.bn3.weight", "layer3.9.bn3.bias", "layer3.9.bn3.running_mean", "layer3.9.bn3.running_var", "layer3.9.bn3.num_batches_tracked", "layer3.10.conv1.weight", "layer3.10.bn1.weight", "layer3.10.bn1.bias", "layer3.10.bn1.running_mean", "layer3.10.bn1.running_var", "layer3.10.bn1.num_batches_tracked", "layer3.10.conv2.weight", "layer3.10.bn2.weight", "layer3.10.bn2.bias", "layer3.10.bn2.running_mean", "layer3.10.bn2.running_var", "layer3.10.bn2.num_batches_tracked", "layer3.10.conv3.weight", "layer3.10.bn3.weight", "layer3.10.bn3.bias", "layer3.10.bn3.running_mean", "layer3.10.bn3.running_var", "layer3.10.bn3.num_batches_tracked", "layer3.11.conv1.weight", "layer3.11.bn1.weight", "layer3.11.bn1.bias", "layer3.11.bn1.running_mean", "layer3.11.bn1.running_var", "layer3.11.bn1.num_batches_tracked", "layer3.11.conv2.weight", "layer3.11.bn2.weight", "layer3.11.bn2.bias", "layer3.11.bn2.running_mean", "layer3.11.bn2.running_var", "layer3.11.bn2.num_batches_tracked", "layer3.11.conv3.weight", "layer3.11.bn3.weight", "layer3.11.bn3.bias", "layer3.11.bn3.running_mean", "layer3.11.bn3.running_var", "layer3.11.bn3.num_batches_tracked", "layer3.12.conv1.weight", "layer3.12.bn1.weight", "layer3.12.bn1.bias", "layer3.12.bn1.running_mean", "layer3.12.bn1.running_var", "layer3.12.bn1.num_batches_tracked", "layer3.12.conv2.weight", "layer3.12.bn2.weight", "layer3.12.bn2.bias", "layer3.12.bn2.running_mean", "layer3.12.bn2.running_var", "layer3.12.bn2.num_batches_tracked", "layer3.12.conv3.weight", "layer3.12.bn3.weight", "layer3.12.bn3.bias", "layer3.12.bn3.running_mean", "layer3.12.bn3.running_var", "layer3.12.bn3.num_batches_tracked", "layer3.13.conv1.weight", "layer3.13.bn1.weight", "layer3.13.bn1.bias", "layer3.13.bn1.running_mean", "layer3.13.bn1.running_var", "layer3.13.bn1.num_batches_tracked", "layer3.13.conv2.weight", "layer3.13.bn2.weight", "layer3.13.bn2.bias", "layer3.13.bn2.running_mean", "layer3.13.bn2.running_var", "layer3.13.bn2.num_batches_tracked", "layer3.13.conv3.weight", "layer3.13.bn3.weight", "layer3.13.bn3.bias", "layer3.13.bn3.running_mean", "layer3.13.bn3.running_var", "layer3.13.bn3.num_batches_tracked", "layer3.14.conv1.weight", "layer3.14.bn1.weight", "layer3.14.bn1.bias", "layer3.14.bn1.running_mean", "layer3.14.bn1.running_var", "layer3.14.bn1.num_batches_tracked", "layer3.14.conv2.weight", "layer3.14.bn2.weight", "layer3.14.bn2.bias", "layer3.14.bn2.running_mean", "layer3.14.bn2.running_var", "layer3.14.bn2.num_batches_tracked", "layer3.14.conv3.weight", "layer3.14.bn3.weight", "layer3.14.bn3.bias", "layer3.14.bn3.running_mean", "layer3.14.bn3.running_var", "layer3.14.bn3.num_batches_tracked", "layer3.15.conv1.weight", "layer3.15.bn1.weight", "layer3.15.bn1.bias", "layer3.15.bn1.running_mean", "layer3.15.bn1.running_var", "layer3.15.bn1.num_batches_tracked", "layer3.15.conv2.weight", "layer3.15.bn2.weight", "layer3.15.bn2.bias", "layer3.15.bn2.running_mean", "layer3.15.bn2.running_var", "layer3.15.bn2.num_batches_tracked", "layer3.15.conv3.weight", "layer3.15.bn3.weight", "layer3.15.bn3.bias", "layer3.15.bn3.running_mean", "layer3.15.bn3.running_var", "layer3.15.bn3.num_batches_tracked", "layer3.16.conv1.weight", "layer3.16.bn1.weight", "layer3.16.bn1.bias", "layer3.16.bn1.running_mean", "layer3.16.bn1.running_var", "layer3.16.bn1.num_batches_tracked", "layer3.16.conv2.weight", "layer3.16.bn2.weight", "layer3.16.bn2.bias", "layer3.16.bn2.running_mean", "layer3.16.bn2.running_var", "layer3.16.bn2.num_batches_tracked", "layer3.16.conv3.weight", "layer3.16.bn3.weight", "layer3.16.bn3.bias", "layer3.16.bn3.running_mean", "layer3.16.bn3.running_var", "layer3.16.bn3.num_batches_tracked", "layer3.17.conv1.weight", "layer3.17.bn1.weight", "layer3.17.bn1.bias", "layer3.17.bn1.running_mean", "layer3.17.bn1.running_var", "layer3.17.bn1.num_batches_tracked", "layer3.17.conv2.weight", "layer3.17.bn2.weight", "layer3.17.bn2.bias", "layer3.17.bn2.running_mean", "layer3.17.bn2.running_var", "layer3.17.bn2.num_batches_tracked", "layer3.17.conv3.weight", "layer3.17.bn3.weight", "layer3.17.bn3.bias", "layer3.17.bn3.running_mean", "layer3.17.bn3.running_var", "layer3.17.bn3.num_batches_tracked", "layer3.18.conv1.weight", "layer3.18.bn1.weight", "layer3.18.bn1.bias", "layer3.18.bn1.running_mean", "layer3.18.bn1.running_var", "layer3.18.bn1.num_batches_tracked", "layer3.18.conv2.weight", "layer3.18.bn2.weight", "layer3.18.bn2.bias", "layer3.18.bn2.running_mean", "layer3.18.bn2.running_var", "layer3.18.bn2.num_batches_tracked", "layer3.18.conv3.weight", "layer3.18.bn3.weight", "layer3.18.bn3.bias", "layer3.18.bn3.running_mean", "layer3.18.bn3.running_var", "layer3.18.bn3.num_batches_tracked", "layer3.19.conv1.weight", "layer3.19.bn1.weight", "layer3.19.bn1.bias", "layer3.19.bn1.running_mean", "layer3.19.bn1.running_var", "layer3.19.bn1.num_batches_tracked", "layer3.19.conv2.weight", "layer3.19.bn2.weight", "layer3.19.bn2.bias", "layer3.19.bn2.running_mean", "layer3.19.bn2.running_var", "layer3.19.bn2.num_batches_tracked", "layer3.19.conv3.weight", "layer3.19.bn3.weight", "layer3.19.bn3.bias", "layer3.19.bn3.running_mean", "layer3.19.bn3.running_var", "layer3.19.bn3.num_batches_tracked", "layer3.20.conv1.weight", "layer3.20.bn1.weight", "layer3.20.bn1.bias", "layer3.20.bn1.running_mean", "layer3.20.bn1.running_var", "layer3.20.bn1.num_batches_tracked", "layer3.20.conv2.weight", "layer3.20.bn2.weight", "layer3.20.bn2.bias", "layer3.20.bn2.running_mean", "layer3.20.bn2.running_var", "layer3.20.bn2.num_batches_tracked", "layer3.20.conv3.weight", "layer3.20.bn3.weight", "layer3.20.bn3.bias", "layer3.20.bn3.running_mean", "layer3.20.bn3.running_var", "layer3.20.bn3.num_batches_tracked", "layer3.21.conv1.weight", "layer3.21.bn1.weight", "layer3.21.bn1.bias", "layer3.21.bn1.running_mean", "layer3.21.bn1.running_var", "layer3.21.bn1.num_batches_tracked", "layer3.21.conv2.weight", "layer3.21.bn2.weight", "layer3.21.bn2.bias", "layer3.21.bn2.running_mean", "layer3.21.bn2.running_var", "layer3.21.bn2.num_batches_tracked", "layer3.21.conv3.weight", "layer3.21.bn3.weight", "layer3.21.bn3.bias", "layer3.21.bn3.running_mean", "layer3.21.bn3.running_var", "layer3.21.bn3.num_batches_tracked", "layer3.22.conv1.weight", "layer3.22.bn1.weight", "layer3.22.bn1.bias", "layer3.22.bn1.running_mean", "layer3.22.bn1.running_var", "layer3.22.bn1.num_batches_tracked", "layer3.22.conv2.weight", "layer3.22.bn2.weight", "layer3.22.bn2.bias", "layer3.22.bn2.running_mean", "layer3.22.bn2.running_var", "layer3.22.bn2.num_batches_tracked", "layer3.22.conv3.weight", "layer3.22.bn3.weight", "layer3.22.bn3.bias", "layer3.22.bn3.running_mean", "layer3.22.bn3.running_var", "layer3.22.bn3.num_batches_tracked", "layer3.23.conv1.weight", "layer3.23.bn1.weight", "layer3.23.bn1.bias", "layer3.23.bn1.running_mean", "layer3.23.bn1.running_var", "layer3.23.bn1.num_batches_tracked", "layer3.23.conv2.weight", "layer3.23.bn2.weight", "layer3.23.bn2.bias", "layer3.23.bn2.running_mean", "layer3.23.bn2.running_var", "layer3.23.bn2.num_batches_tracked", "layer3.23.conv3.weight", "layer3.23.bn3.weight", "layer3.23.bn3.bias", "layer3.23.bn3.running_mean", "layer3.23.bn3.running_var", "layer3.23.bn3.num_batches_tracked", "layer3.24.conv1.weight", "layer3.24.bn1.weight", "layer3.24.bn1.bias", "layer3.24.bn1.running_mean", "layer3.24.bn1.running_var", "layer3.24.bn1.num_batches_tracked", "layer3.24.conv2.weight", "layer3.24.bn2.weight", "layer3.24.bn2.bias", "layer3.24.bn2.running_mean", "layer3.24.bn2.running_var", "layer3.24.bn2.num_batches_tracked", "layer3.24.conv3.weight", "layer3.24.bn3.weight", "layer3.24.bn3.bias", "layer3.24.bn3.running_mean", "layer3.24.bn3.running_var", "layer3.24.bn3.num_batches_tracked", "layer3.25.conv1.weight", "layer3.25.bn1.weight", "layer3.25.bn1.bias", "layer3.25.bn1.running_mean", "layer3.25.bn1.running_var", "layer3.25.bn1.num_batches_tracked", "layer3.25.conv2.weight", "layer3.25.bn2.weight", "layer3.25.bn2.bias", "layer3.25.bn2.running_mean", "layer3.25.bn2.running_var", "layer3.25.bn2.num_batches_tracked", "layer3.25.conv3.weight", "layer3.25.bn3.weight", "layer3.25.bn3.bias", "layer3.25.bn3.running_mean", "layer3.25.bn3.running_var", "layer3.25.bn3.num_batches_tracked", "layer3.26.conv1.weight", "layer3.26.bn1.weight", "layer3.26.bn1.bias", "layer3.26.bn1.running_mean", "layer3.26.bn1.running_var", "layer3.26.bn1.num_batches_tracked", "layer3.26.conv2.weight", "layer3.26.bn2.weight", "layer3.26.bn2.bias", "layer3.26.bn2.running_mean", "layer3.26.bn2.running_var", "layer3.26.bn2.num_batches_tracked", "layer3.26.conv3.weight", "layer3.26.bn3.weight", "layer3.26.bn3.bias", "layer3.26.bn3.running_mean", "layer3.26.bn3.running_var", "layer3.26.bn3.num_batches_tracked", "layer3.27.conv1.weight", "layer3.27.bn1.weight", "layer3.27.bn1.bias", "layer3.27.bn1.running_mean", "layer3.27.bn1.running_var", "layer3.27.bn1.num_batches_tracked", "layer3.27.conv2.weight", "layer3.27.bn2.weight", "layer3.27.bn2.bias", "layer3.27.bn2.running_mean", "layer3.27.bn2.running_var", "layer3.27.bn2.num_batches_tracked", "layer3.27.conv3.weight", "layer3.27.bn3.weight", "layer3.27.bn3.bias", "layer3.27.bn3.running_mean", "layer3.27.bn3.running_var", "layer3.27.bn3.num_batches_tracked", "layer3.28.conv1.weight", "layer3.28.bn1.weight", "layer3.28.bn1.bias", "layer3.28.bn1.running_mean", "layer3.28.bn1.running_var", "layer3.28.bn1.num_batches_tracked", "layer3.28.conv2.weight", "layer3.28.bn2.weight", "layer3.28.bn2.bias", "layer3.28.bn2.running_mean", "layer3.28.bn2.running_var", "layer3.28.bn2.num_batches_tracked", "layer3.28.conv3.weight", "layer3.28.bn3.weight", "layer3.28.bn3.bias", "layer3.28.bn3.running_mean", "layer3.28.bn3.running_var", "layer3.28.bn3.num_batches_tracked", "layer3.29.conv1.weight", "layer3.29.bn1.weight", "layer3.29.bn1.bias", "layer3.29.bn1.running_mean", "layer3.29.bn1.running_var", "layer3.29.bn1.num_batches_tracked", "layer3.29.conv2.weight", "layer3.29.bn2.weight", "layer3.29.bn2.bias", "layer3.29.bn2.running_mean", "layer3.29.bn2.running_var", "layer3.29.bn2.num_batches_tracked", "layer3.29.conv3.weight", "layer3.29.bn3.weight", "layer3.29.bn3.bias", "layer3.29.bn3.running_mean", "layer3.29.bn3.running_var", "layer3.29.bn3.num_batches_tracked", "layer3.30.conv1.weight", "layer3.30.bn1.weight", "layer3.30.bn1.bias", "layer3.30.bn1.running_mean", "layer3.30.bn1.running_var", "layer3.30.bn1.num_batches_tracked", "layer3.30.conv2.weight", "layer3.30.bn2.weight", "layer3.30.bn2.bias", "layer3.30.bn2.running_mean", "layer3.30.bn2.running_var", "layer3.30.bn2.num_batches_tracked", "layer3.30.conv3.weight", "layer3.30.bn3.weight", "layer3.30.bn3.bias", "layer3.30.bn3.running_mean", "layer3.30.bn3.running_var", "layer3.30.bn3.num_batches_tracked", "layer3.31.conv1.weight", "layer3.31.bn1.weight", "layer3.31.bn1.bias", "layer3.31.bn1.running_mean", "layer3.31.bn1.running_var", "layer3.31.bn1.num_batches_tracked", "layer3.31.conv2.weight", "layer3.31.bn2.weight", "layer3.31.bn2.bias", "layer3.31.bn2.running_mean", "layer3.31.bn2.running_var", "layer3.31.bn2.num_batches_tracked", "layer3.31.conv3.weight", "layer3.31.bn3.weight", "layer3.31.bn3.bias", "layer3.31.bn3.running_mean", "layer3.31.bn3.running_var", "layer3.31.bn3.num_batches_tracked", "layer3.32.conv1.weight", "layer3.32.bn1.weight", "layer3.32.bn1.bias", "layer3.32.bn1.running_mean", "layer3.32.bn1.running_var", "layer3.32.bn1.num_batches_tracked", "layer3.32.conv2.weight", "layer3.32.bn2.weight", "layer3.32.bn2.bias", "layer3.32.bn2.running_mean", "layer3.32.bn2.running_var", "layer3.32.bn2.num_batches_tracked", "layer3.32.conv3.weight", "layer3.32.bn3.weight", "layer3.32.bn3.bias", "layer3.32.bn3.running_mean", "layer3.32.bn3.running_var", "layer3.32.bn3.num_batches_tracked", "layer3.33.conv1.weight", "layer3.33.bn1.weight", "layer3.33.bn1.bias", "layer3.33.bn1.running_mean", "layer3.33.bn1.running_var", "layer3.33.bn1.num_batches_tracked", "layer3.33.conv2.weight", "layer3.33.bn2.weight", "layer3.33.bn2.bias", "layer3.33.bn2.running_mean", "layer3.33.bn2.running_var", "layer3.33.bn2.num_batches_tracked", "layer3.33.conv3.weight", "layer3.33.bn3.weight", "layer3.33.bn3.bias", "layer3.33.bn3.running_mean", "layer3.33.bn3.running_var", "layer3.33.bn3.num_batches_tracked", "layer3.34.conv1.weight", "layer3.34.bn1.weight", "layer3.34.bn1.bias", "layer3.34.bn1.running_mean", "layer3.34.bn1.running_var", "layer3.34.bn1.num_batches_tracked", "layer3.34.conv2.weight", "layer3.34.bn2.weight", "layer3.34.bn2.bias", "layer3.34.bn2.running_mean", "layer3.34.bn2.running_var", "layer3.34.bn2.num_batches_tracked", "layer3.34.conv3.weight", "layer3.34.bn3.weight", "layer3.34.bn3.bias", "layer3.34.bn3.running_mean", "layer3.34.bn3.running_var", "layer3.34.bn3.num_batches_tracked", "layer3.35.conv1.weight", "layer3.35.bn1.weight", "layer3.35.bn1.bias", "layer3.35.bn1.running_mean", "layer3.35.bn1.running_var", "layer3.35.bn1.num_batches_tracked", "layer3.35.conv2.weight", "layer3.35.bn2.weight", "layer3.35.bn2.bias", "layer3.35.bn2.running_mean", "layer3.35.bn2.running_var", "layer3.35.bn2.num_batches_tracked", "layer3.35.conv3.weight", "layer3.35.bn3.weight", "layer3.35.bn3.bias", "layer3.35.bn3.running_mean", "layer3.35.bn3.running_var", "layer3.35.bn3.num_batches_tracked", "layer4.0.conv1.weight", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn1.num_batches_tracked", "layer4.0.conv2.weight", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn2.num_batches_tracked", "layer4.0.conv3.weight", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.0.bn3.num_batches_tracked", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.downsample.1.num_batches_tracked", "layer4.1.conv1.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn1.num_batches_tracked", "layer4.1.conv2.weight", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn2.num_batches_tracked", "layer4.1.conv3.weight", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.1.bn3.num_batches_tracked", "layer4.2.conv1.weight", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", "layer4.2.bn1.running_var", "layer4.2.bn1.num_batches_tracked", "layer4.2.conv2.weight", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.bn2.num_batches_tracked", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var", "layer4.2.bn3.num_batches_tracked". 

In [64]:
model = VideoTransformer(
    vocab_size=len(vocab),
    d_model=512,
    nhead=8,
    num_layers=3,        # ✅ Match this!
    dim_ff=2048,
    max_len=45,
    dropout=0.1,
    input_feat_dim=2048
).to(device)

# Now load the matching weights
model.load_state_dict(torch.load('/content/drive/MyDrive/transformer_checkpoints/transformer_final.pt'))
model.eval()


RuntimeError: Error(s) in loading state_dict for VideoTransformer:
	Missing key(s) in state_dict: "embedding.weight", "pos_enc.pe", "vid_fc.weight", "vid_fc.bias", "decoder.layers.0.self_attn.in_proj_weight", "decoder.layers.0.self_attn.in_proj_bias", "decoder.layers.0.self_attn.out_proj.weight", "decoder.layers.0.self_attn.out_proj.bias", "decoder.layers.0.multihead_attn.in_proj_weight", "decoder.layers.0.multihead_attn.in_proj_bias", "decoder.layers.0.multihead_attn.out_proj.weight", "decoder.layers.0.multihead_attn.out_proj.bias", "decoder.layers.0.linear1.weight", "decoder.layers.0.linear1.bias", "decoder.layers.0.linear2.weight", "decoder.layers.0.linear2.bias", "decoder.layers.0.norm1.weight", "decoder.layers.0.norm1.bias", "decoder.layers.0.norm2.weight", "decoder.layers.0.norm2.bias", "decoder.layers.0.norm3.weight", "decoder.layers.0.norm3.bias", "decoder.layers.1.self_attn.in_proj_weight", "decoder.layers.1.self_attn.in_proj_bias", "decoder.layers.1.self_attn.out_proj.weight", "decoder.layers.1.self_attn.out_proj.bias", "decoder.layers.1.multihead_attn.in_proj_weight", "decoder.layers.1.multihead_attn.in_proj_bias", "decoder.layers.1.multihead_attn.out_proj.weight", "decoder.layers.1.multihead_attn.out_proj.bias", "decoder.layers.1.linear1.weight", "decoder.layers.1.linear1.bias", "decoder.layers.1.linear2.weight", "decoder.layers.1.linear2.bias", "decoder.layers.1.norm1.weight", "decoder.layers.1.norm1.bias", "decoder.layers.1.norm2.weight", "decoder.layers.1.norm2.bias", "decoder.layers.1.norm3.weight", "decoder.layers.1.norm3.bias", "decoder.layers.2.self_attn.in_proj_weight", "decoder.layers.2.self_attn.in_proj_bias", "decoder.layers.2.self_attn.out_proj.weight", "decoder.layers.2.self_attn.out_proj.bias", "decoder.layers.2.multihead_attn.in_proj_weight", "decoder.layers.2.multihead_attn.in_proj_bias", "decoder.layers.2.multihead_attn.out_proj.weight", "decoder.layers.2.multihead_attn.out_proj.bias", "decoder.layers.2.linear1.weight", "decoder.layers.2.linear1.bias", "decoder.layers.2.linear2.weight", "decoder.layers.2.linear2.bias", "decoder.layers.2.norm1.weight", "decoder.layers.2.norm1.bias", "decoder.layers.2.norm2.weight", "decoder.layers.2.norm2.bias", "decoder.layers.2.norm3.weight", "decoder.layers.2.norm3.bias", "output.weight", "output.bias". 
	Unexpected key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "bn1.num_batches_tracked", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn1.num_batches_tracked", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.bn2.num_batches_tracked", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.bn3.num_batches_tracked", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.0.downsample.1.num_batches_tracked", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.bn1.num_batches_tracked", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.bn2.num_batches_tracked", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.1.bn3.num_batches_tracked", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.bn1.num_batches_tracked", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.bn2.num_batches_tracked", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer1.2.bn3.num_batches_tracked", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.bn1.num_batches_tracked", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.bn2.num_batches_tracked", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.0.bn3.num_batches_tracked", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.0.downsample.1.num_batches_tracked", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.bn1.num_batches_tracked", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.1.bn2.num_batches_tracked", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.1.bn3.num_batches_tracked", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.bn1.num_batches_tracked", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.bn2.num_batches_tracked", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.2.bn3.num_batches_tracked", "layer2.3.conv1.weight", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.bn1.num_batches_tracked", "layer2.3.conv2.weight", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer2.3.bn2.num_batches_tracked", "layer2.3.conv3.weight", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer2.3.bn3.num_batches_tracked", "layer2.4.conv1.weight", "layer2.4.bn1.weight", "layer2.4.bn1.bias", "layer2.4.bn1.running_mean", "layer2.4.bn1.running_var", "layer2.4.bn1.num_batches_tracked", "layer2.4.conv2.weight", "layer2.4.bn2.weight", "layer2.4.bn2.bias", "layer2.4.bn2.running_mean", "layer2.4.bn2.running_var", "layer2.4.bn2.num_batches_tracked", "layer2.4.conv3.weight", "layer2.4.bn3.weight", "layer2.4.bn3.bias", "layer2.4.bn3.running_mean", "layer2.4.bn3.running_var", "layer2.4.bn3.num_batches_tracked", "layer2.5.conv1.weight", "layer2.5.bn1.weight", "layer2.5.bn1.bias", "layer2.5.bn1.running_mean", "layer2.5.bn1.running_var", "layer2.5.bn1.num_batches_tracked", "layer2.5.conv2.weight", "layer2.5.bn2.weight", "layer2.5.bn2.bias", "layer2.5.bn2.running_mean", "layer2.5.bn2.running_var", "layer2.5.bn2.num_batches_tracked", "layer2.5.conv3.weight", "layer2.5.bn3.weight", "layer2.5.bn3.bias", "layer2.5.bn3.running_mean", "layer2.5.bn3.running_var", "layer2.5.bn3.num_batches_tracked", "layer2.6.conv1.weight", "layer2.6.bn1.weight", "layer2.6.bn1.bias", "layer2.6.bn1.running_mean", "layer2.6.bn1.running_var", "layer2.6.bn1.num_batches_tracked", "layer2.6.conv2.weight", "layer2.6.bn2.weight", "layer2.6.bn2.bias", "layer2.6.bn2.running_mean", "layer2.6.bn2.running_var", "layer2.6.bn2.num_batches_tracked", "layer2.6.conv3.weight", "layer2.6.bn3.weight", "layer2.6.bn3.bias", "layer2.6.bn3.running_mean", "layer2.6.bn3.running_var", "layer2.6.bn3.num_batches_tracked", "layer2.7.conv1.weight", "layer2.7.bn1.weight", "layer2.7.bn1.bias", "layer2.7.bn1.running_mean", "layer2.7.bn1.running_var", "layer2.7.bn1.num_batches_tracked", "layer2.7.conv2.weight", "layer2.7.bn2.weight", "layer2.7.bn2.bias", "layer2.7.bn2.running_mean", "layer2.7.bn2.running_var", "layer2.7.bn2.num_batches_tracked", "layer2.7.conv3.weight", "layer2.7.bn3.weight", "layer2.7.bn3.bias", "layer2.7.bn3.running_mean", "layer2.7.bn3.running_var", "layer2.7.bn3.num_batches_tracked", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.bn1.num_batches_tracked", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.bn2.num_batches_tracked", "layer3.0.conv3.weight", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.0.bn3.num_batches_tracked", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.0.downsample.1.num_batches_tracked", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.bn1.num_batches_tracked", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.1.bn2.num_batches_tracked", "layer3.1.conv3.weight", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.1.bn3.num_batches_tracked", "layer3.2.conv1.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.bn1.num_batches_tracked", "layer3.2.conv2.weight", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.2.bn2.num_batches_tracked", "layer3.2.conv3.weight", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.2.bn3.num_batches_tracked", "layer3.3.conv1.weight", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.bn1.num_batches_tracked", "layer3.3.conv2.weight", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.3.bn2.num_batches_tracked", "layer3.3.conv3.weight", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.3.bn3.num_batches_tracked", "layer3.4.conv1.weight", "layer3.4.bn1.weight", "layer3.4.bn1.bias", "layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.bn1.num_batches_tracked", "layer3.4.conv2.weight", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.4.bn2.num_batches_tracked", "layer3.4.conv3.weight", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.4.bn3.num_batches_tracked", "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.bn1.num_batches_tracked", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.bn2.num_batches_tracked", "layer3.5.conv3.weight", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer3.5.bn3.num_batches_tracked", "layer3.6.conv1.weight", "layer3.6.bn1.weight", "layer3.6.bn1.bias", "layer3.6.bn1.running_mean", "layer3.6.bn1.running_var", "layer3.6.bn1.num_batches_tracked", "layer3.6.conv2.weight", "layer3.6.bn2.weight", "layer3.6.bn2.bias", "layer3.6.bn2.running_mean", "layer3.6.bn2.running_var", "layer3.6.bn2.num_batches_tracked", "layer3.6.conv3.weight", "layer3.6.bn3.weight", "layer3.6.bn3.bias", "layer3.6.bn3.running_mean", "layer3.6.bn3.running_var", "layer3.6.bn3.num_batches_tracked", "layer3.7.conv1.weight", "layer3.7.bn1.weight", "layer3.7.bn1.bias", "layer3.7.bn1.running_mean", "layer3.7.bn1.running_var", "layer3.7.bn1.num_batches_tracked", "layer3.7.conv2.weight", "layer3.7.bn2.weight", "layer3.7.bn2.bias", "layer3.7.bn2.running_mean", "layer3.7.bn2.running_var", "layer3.7.bn2.num_batches_tracked", "layer3.7.conv3.weight", "layer3.7.bn3.weight", "layer3.7.bn3.bias", "layer3.7.bn3.running_mean", "layer3.7.bn3.running_var", "layer3.7.bn3.num_batches_tracked", "layer3.8.conv1.weight", "layer3.8.bn1.weight", "layer3.8.bn1.bias", "layer3.8.bn1.running_mean", "layer3.8.bn1.running_var", "layer3.8.bn1.num_batches_tracked", "layer3.8.conv2.weight", "layer3.8.bn2.weight", "layer3.8.bn2.bias", "layer3.8.bn2.running_mean", "layer3.8.bn2.running_var", "layer3.8.bn2.num_batches_tracked", "layer3.8.conv3.weight", "layer3.8.bn3.weight", "layer3.8.bn3.bias", "layer3.8.bn3.running_mean", "layer3.8.bn3.running_var", "layer3.8.bn3.num_batches_tracked", "layer3.9.conv1.weight", "layer3.9.bn1.weight", "layer3.9.bn1.bias", "layer3.9.bn1.running_mean", "layer3.9.bn1.running_var", "layer3.9.bn1.num_batches_tracked", "layer3.9.conv2.weight", "layer3.9.bn2.weight", "layer3.9.bn2.bias", "layer3.9.bn2.running_mean", "layer3.9.bn2.running_var", "layer3.9.bn2.num_batches_tracked", "layer3.9.conv3.weight", "layer3.9.bn3.weight", "layer3.9.bn3.bias", "layer3.9.bn3.running_mean", "layer3.9.bn3.running_var", "layer3.9.bn3.num_batches_tracked", "layer3.10.conv1.weight", "layer3.10.bn1.weight", "layer3.10.bn1.bias", "layer3.10.bn1.running_mean", "layer3.10.bn1.running_var", "layer3.10.bn1.num_batches_tracked", "layer3.10.conv2.weight", "layer3.10.bn2.weight", "layer3.10.bn2.bias", "layer3.10.bn2.running_mean", "layer3.10.bn2.running_var", "layer3.10.bn2.num_batches_tracked", "layer3.10.conv3.weight", "layer3.10.bn3.weight", "layer3.10.bn3.bias", "layer3.10.bn3.running_mean", "layer3.10.bn3.running_var", "layer3.10.bn3.num_batches_tracked", "layer3.11.conv1.weight", "layer3.11.bn1.weight", "layer3.11.bn1.bias", "layer3.11.bn1.running_mean", "layer3.11.bn1.running_var", "layer3.11.bn1.num_batches_tracked", "layer3.11.conv2.weight", "layer3.11.bn2.weight", "layer3.11.bn2.bias", "layer3.11.bn2.running_mean", "layer3.11.bn2.running_var", "layer3.11.bn2.num_batches_tracked", "layer3.11.conv3.weight", "layer3.11.bn3.weight", "layer3.11.bn3.bias", "layer3.11.bn3.running_mean", "layer3.11.bn3.running_var", "layer3.11.bn3.num_batches_tracked", "layer3.12.conv1.weight", "layer3.12.bn1.weight", "layer3.12.bn1.bias", "layer3.12.bn1.running_mean", "layer3.12.bn1.running_var", "layer3.12.bn1.num_batches_tracked", "layer3.12.conv2.weight", "layer3.12.bn2.weight", "layer3.12.bn2.bias", "layer3.12.bn2.running_mean", "layer3.12.bn2.running_var", "layer3.12.bn2.num_batches_tracked", "layer3.12.conv3.weight", "layer3.12.bn3.weight", "layer3.12.bn3.bias", "layer3.12.bn3.running_mean", "layer3.12.bn3.running_var", "layer3.12.bn3.num_batches_tracked", "layer3.13.conv1.weight", "layer3.13.bn1.weight", "layer3.13.bn1.bias", "layer3.13.bn1.running_mean", "layer3.13.bn1.running_var", "layer3.13.bn1.num_batches_tracked", "layer3.13.conv2.weight", "layer3.13.bn2.weight", "layer3.13.bn2.bias", "layer3.13.bn2.running_mean", "layer3.13.bn2.running_var", "layer3.13.bn2.num_batches_tracked", "layer3.13.conv3.weight", "layer3.13.bn3.weight", "layer3.13.bn3.bias", "layer3.13.bn3.running_mean", "layer3.13.bn3.running_var", "layer3.13.bn3.num_batches_tracked", "layer3.14.conv1.weight", "layer3.14.bn1.weight", "layer3.14.bn1.bias", "layer3.14.bn1.running_mean", "layer3.14.bn1.running_var", "layer3.14.bn1.num_batches_tracked", "layer3.14.conv2.weight", "layer3.14.bn2.weight", "layer3.14.bn2.bias", "layer3.14.bn2.running_mean", "layer3.14.bn2.running_var", "layer3.14.bn2.num_batches_tracked", "layer3.14.conv3.weight", "layer3.14.bn3.weight", "layer3.14.bn3.bias", "layer3.14.bn3.running_mean", "layer3.14.bn3.running_var", "layer3.14.bn3.num_batches_tracked", "layer3.15.conv1.weight", "layer3.15.bn1.weight", "layer3.15.bn1.bias", "layer3.15.bn1.running_mean", "layer3.15.bn1.running_var", "layer3.15.bn1.num_batches_tracked", "layer3.15.conv2.weight", "layer3.15.bn2.weight", "layer3.15.bn2.bias", "layer3.15.bn2.running_mean", "layer3.15.bn2.running_var", "layer3.15.bn2.num_batches_tracked", "layer3.15.conv3.weight", "layer3.15.bn3.weight", "layer3.15.bn3.bias", "layer3.15.bn3.running_mean", "layer3.15.bn3.running_var", "layer3.15.bn3.num_batches_tracked", "layer3.16.conv1.weight", "layer3.16.bn1.weight", "layer3.16.bn1.bias", "layer3.16.bn1.running_mean", "layer3.16.bn1.running_var", "layer3.16.bn1.num_batches_tracked", "layer3.16.conv2.weight", "layer3.16.bn2.weight", "layer3.16.bn2.bias", "layer3.16.bn2.running_mean", "layer3.16.bn2.running_var", "layer3.16.bn2.num_batches_tracked", "layer3.16.conv3.weight", "layer3.16.bn3.weight", "layer3.16.bn3.bias", "layer3.16.bn3.running_mean", "layer3.16.bn3.running_var", "layer3.16.bn3.num_batches_tracked", "layer3.17.conv1.weight", "layer3.17.bn1.weight", "layer3.17.bn1.bias", "layer3.17.bn1.running_mean", "layer3.17.bn1.running_var", "layer3.17.bn1.num_batches_tracked", "layer3.17.conv2.weight", "layer3.17.bn2.weight", "layer3.17.bn2.bias", "layer3.17.bn2.running_mean", "layer3.17.bn2.running_var", "layer3.17.bn2.num_batches_tracked", "layer3.17.conv3.weight", "layer3.17.bn3.weight", "layer3.17.bn3.bias", "layer3.17.bn3.running_mean", "layer3.17.bn3.running_var", "layer3.17.bn3.num_batches_tracked", "layer3.18.conv1.weight", "layer3.18.bn1.weight", "layer3.18.bn1.bias", "layer3.18.bn1.running_mean", "layer3.18.bn1.running_var", "layer3.18.bn1.num_batches_tracked", "layer3.18.conv2.weight", "layer3.18.bn2.weight", "layer3.18.bn2.bias", "layer3.18.bn2.running_mean", "layer3.18.bn2.running_var", "layer3.18.bn2.num_batches_tracked", "layer3.18.conv3.weight", "layer3.18.bn3.weight", "layer3.18.bn3.bias", "layer3.18.bn3.running_mean", "layer3.18.bn3.running_var", "layer3.18.bn3.num_batches_tracked", "layer3.19.conv1.weight", "layer3.19.bn1.weight", "layer3.19.bn1.bias", "layer3.19.bn1.running_mean", "layer3.19.bn1.running_var", "layer3.19.bn1.num_batches_tracked", "layer3.19.conv2.weight", "layer3.19.bn2.weight", "layer3.19.bn2.bias", "layer3.19.bn2.running_mean", "layer3.19.bn2.running_var", "layer3.19.bn2.num_batches_tracked", "layer3.19.conv3.weight", "layer3.19.bn3.weight", "layer3.19.bn3.bias", "layer3.19.bn3.running_mean", "layer3.19.bn3.running_var", "layer3.19.bn3.num_batches_tracked", "layer3.20.conv1.weight", "layer3.20.bn1.weight", "layer3.20.bn1.bias", "layer3.20.bn1.running_mean", "layer3.20.bn1.running_var", "layer3.20.bn1.num_batches_tracked", "layer3.20.conv2.weight", "layer3.20.bn2.weight", "layer3.20.bn2.bias", "layer3.20.bn2.running_mean", "layer3.20.bn2.running_var", "layer3.20.bn2.num_batches_tracked", "layer3.20.conv3.weight", "layer3.20.bn3.weight", "layer3.20.bn3.bias", "layer3.20.bn3.running_mean", "layer3.20.bn3.running_var", "layer3.20.bn3.num_batches_tracked", "layer3.21.conv1.weight", "layer3.21.bn1.weight", "layer3.21.bn1.bias", "layer3.21.bn1.running_mean", "layer3.21.bn1.running_var", "layer3.21.bn1.num_batches_tracked", "layer3.21.conv2.weight", "layer3.21.bn2.weight", "layer3.21.bn2.bias", "layer3.21.bn2.running_mean", "layer3.21.bn2.running_var", "layer3.21.bn2.num_batches_tracked", "layer3.21.conv3.weight", "layer3.21.bn3.weight", "layer3.21.bn3.bias", "layer3.21.bn3.running_mean", "layer3.21.bn3.running_var", "layer3.21.bn3.num_batches_tracked", "layer3.22.conv1.weight", "layer3.22.bn1.weight", "layer3.22.bn1.bias", "layer3.22.bn1.running_mean", "layer3.22.bn1.running_var", "layer3.22.bn1.num_batches_tracked", "layer3.22.conv2.weight", "layer3.22.bn2.weight", "layer3.22.bn2.bias", "layer3.22.bn2.running_mean", "layer3.22.bn2.running_var", "layer3.22.bn2.num_batches_tracked", "layer3.22.conv3.weight", "layer3.22.bn3.weight", "layer3.22.bn3.bias", "layer3.22.bn3.running_mean", "layer3.22.bn3.running_var", "layer3.22.bn3.num_batches_tracked", "layer3.23.conv1.weight", "layer3.23.bn1.weight", "layer3.23.bn1.bias", "layer3.23.bn1.running_mean", "layer3.23.bn1.running_var", "layer3.23.bn1.num_batches_tracked", "layer3.23.conv2.weight", "layer3.23.bn2.weight", "layer3.23.bn2.bias", "layer3.23.bn2.running_mean", "layer3.23.bn2.running_var", "layer3.23.bn2.num_batches_tracked", "layer3.23.conv3.weight", "layer3.23.bn3.weight", "layer3.23.bn3.bias", "layer3.23.bn3.running_mean", "layer3.23.bn3.running_var", "layer3.23.bn3.num_batches_tracked", "layer3.24.conv1.weight", "layer3.24.bn1.weight", "layer3.24.bn1.bias", "layer3.24.bn1.running_mean", "layer3.24.bn1.running_var", "layer3.24.bn1.num_batches_tracked", "layer3.24.conv2.weight", "layer3.24.bn2.weight", "layer3.24.bn2.bias", "layer3.24.bn2.running_mean", "layer3.24.bn2.running_var", "layer3.24.bn2.num_batches_tracked", "layer3.24.conv3.weight", "layer3.24.bn3.weight", "layer3.24.bn3.bias", "layer3.24.bn3.running_mean", "layer3.24.bn3.running_var", "layer3.24.bn3.num_batches_tracked", "layer3.25.conv1.weight", "layer3.25.bn1.weight", "layer3.25.bn1.bias", "layer3.25.bn1.running_mean", "layer3.25.bn1.running_var", "layer3.25.bn1.num_batches_tracked", "layer3.25.conv2.weight", "layer3.25.bn2.weight", "layer3.25.bn2.bias", "layer3.25.bn2.running_mean", "layer3.25.bn2.running_var", "layer3.25.bn2.num_batches_tracked", "layer3.25.conv3.weight", "layer3.25.bn3.weight", "layer3.25.bn3.bias", "layer3.25.bn3.running_mean", "layer3.25.bn3.running_var", "layer3.25.bn3.num_batches_tracked", "layer3.26.conv1.weight", "layer3.26.bn1.weight", "layer3.26.bn1.bias", "layer3.26.bn1.running_mean", "layer3.26.bn1.running_var", "layer3.26.bn1.num_batches_tracked", "layer3.26.conv2.weight", "layer3.26.bn2.weight", "layer3.26.bn2.bias", "layer3.26.bn2.running_mean", "layer3.26.bn2.running_var", "layer3.26.bn2.num_batches_tracked", "layer3.26.conv3.weight", "layer3.26.bn3.weight", "layer3.26.bn3.bias", "layer3.26.bn3.running_mean", "layer3.26.bn3.running_var", "layer3.26.bn3.num_batches_tracked", "layer3.27.conv1.weight", "layer3.27.bn1.weight", "layer3.27.bn1.bias", "layer3.27.bn1.running_mean", "layer3.27.bn1.running_var", "layer3.27.bn1.num_batches_tracked", "layer3.27.conv2.weight", "layer3.27.bn2.weight", "layer3.27.bn2.bias", "layer3.27.bn2.running_mean", "layer3.27.bn2.running_var", "layer3.27.bn2.num_batches_tracked", "layer3.27.conv3.weight", "layer3.27.bn3.weight", "layer3.27.bn3.bias", "layer3.27.bn3.running_mean", "layer3.27.bn3.running_var", "layer3.27.bn3.num_batches_tracked", "layer3.28.conv1.weight", "layer3.28.bn1.weight", "layer3.28.bn1.bias", "layer3.28.bn1.running_mean", "layer3.28.bn1.running_var", "layer3.28.bn1.num_batches_tracked", "layer3.28.conv2.weight", "layer3.28.bn2.weight", "layer3.28.bn2.bias", "layer3.28.bn2.running_mean", "layer3.28.bn2.running_var", "layer3.28.bn2.num_batches_tracked", "layer3.28.conv3.weight", "layer3.28.bn3.weight", "layer3.28.bn3.bias", "layer3.28.bn3.running_mean", "layer3.28.bn3.running_var", "layer3.28.bn3.num_batches_tracked", "layer3.29.conv1.weight", "layer3.29.bn1.weight", "layer3.29.bn1.bias", "layer3.29.bn1.running_mean", "layer3.29.bn1.running_var", "layer3.29.bn1.num_batches_tracked", "layer3.29.conv2.weight", "layer3.29.bn2.weight", "layer3.29.bn2.bias", "layer3.29.bn2.running_mean", "layer3.29.bn2.running_var", "layer3.29.bn2.num_batches_tracked", "layer3.29.conv3.weight", "layer3.29.bn3.weight", "layer3.29.bn3.bias", "layer3.29.bn3.running_mean", "layer3.29.bn3.running_var", "layer3.29.bn3.num_batches_tracked", "layer3.30.conv1.weight", "layer3.30.bn1.weight", "layer3.30.bn1.bias", "layer3.30.bn1.running_mean", "layer3.30.bn1.running_var", "layer3.30.bn1.num_batches_tracked", "layer3.30.conv2.weight", "layer3.30.bn2.weight", "layer3.30.bn2.bias", "layer3.30.bn2.running_mean", "layer3.30.bn2.running_var", "layer3.30.bn2.num_batches_tracked", "layer3.30.conv3.weight", "layer3.30.bn3.weight", "layer3.30.bn3.bias", "layer3.30.bn3.running_mean", "layer3.30.bn3.running_var", "layer3.30.bn3.num_batches_tracked", "layer3.31.conv1.weight", "layer3.31.bn1.weight", "layer3.31.bn1.bias", "layer3.31.bn1.running_mean", "layer3.31.bn1.running_var", "layer3.31.bn1.num_batches_tracked", "layer3.31.conv2.weight", "layer3.31.bn2.weight", "layer3.31.bn2.bias", "layer3.31.bn2.running_mean", "layer3.31.bn2.running_var", "layer3.31.bn2.num_batches_tracked", "layer3.31.conv3.weight", "layer3.31.bn3.weight", "layer3.31.bn3.bias", "layer3.31.bn3.running_mean", "layer3.31.bn3.running_var", "layer3.31.bn3.num_batches_tracked", "layer3.32.conv1.weight", "layer3.32.bn1.weight", "layer3.32.bn1.bias", "layer3.32.bn1.running_mean", "layer3.32.bn1.running_var", "layer3.32.bn1.num_batches_tracked", "layer3.32.conv2.weight", "layer3.32.bn2.weight", "layer3.32.bn2.bias", "layer3.32.bn2.running_mean", "layer3.32.bn2.running_var", "layer3.32.bn2.num_batches_tracked", "layer3.32.conv3.weight", "layer3.32.bn3.weight", "layer3.32.bn3.bias", "layer3.32.bn3.running_mean", "layer3.32.bn3.running_var", "layer3.32.bn3.num_batches_tracked", "layer3.33.conv1.weight", "layer3.33.bn1.weight", "layer3.33.bn1.bias", "layer3.33.bn1.running_mean", "layer3.33.bn1.running_var", "layer3.33.bn1.num_batches_tracked", "layer3.33.conv2.weight", "layer3.33.bn2.weight", "layer3.33.bn2.bias", "layer3.33.bn2.running_mean", "layer3.33.bn2.running_var", "layer3.33.bn2.num_batches_tracked", "layer3.33.conv3.weight", "layer3.33.bn3.weight", "layer3.33.bn3.bias", "layer3.33.bn3.running_mean", "layer3.33.bn3.running_var", "layer3.33.bn3.num_batches_tracked", "layer3.34.conv1.weight", "layer3.34.bn1.weight", "layer3.34.bn1.bias", "layer3.34.bn1.running_mean", "layer3.34.bn1.running_var", "layer3.34.bn1.num_batches_tracked", "layer3.34.conv2.weight", "layer3.34.bn2.weight", "layer3.34.bn2.bias", "layer3.34.bn2.running_mean", "layer3.34.bn2.running_var", "layer3.34.bn2.num_batches_tracked", "layer3.34.conv3.weight", "layer3.34.bn3.weight", "layer3.34.bn3.bias", "layer3.34.bn3.running_mean", "layer3.34.bn3.running_var", "layer3.34.bn3.num_batches_tracked", "layer3.35.conv1.weight", "layer3.35.bn1.weight", "layer3.35.bn1.bias", "layer3.35.bn1.running_mean", "layer3.35.bn1.running_var", "layer3.35.bn1.num_batches_tracked", "layer3.35.conv2.weight", "layer3.35.bn2.weight", "layer3.35.bn2.bias", "layer3.35.bn2.running_mean", "layer3.35.bn2.running_var", "layer3.35.bn2.num_batches_tracked", "layer3.35.conv3.weight", "layer3.35.bn3.weight", "layer3.35.bn3.bias", "layer3.35.bn3.running_mean", "layer3.35.bn3.running_var", "layer3.35.bn3.num_batches_tracked", "layer4.0.conv1.weight", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn1.num_batches_tracked", "layer4.0.conv2.weight", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn2.num_batches_tracked", "layer4.0.conv3.weight", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.0.bn3.num_batches_tracked", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.downsample.1.num_batches_tracked", "layer4.1.conv1.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn1.num_batches_tracked", "layer4.1.conv2.weight", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn2.num_batches_tracked", "layer4.1.conv3.weight", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.1.bn3.num_batches_tracked", "layer4.2.conv1.weight", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", "layer4.2.bn1.running_var", "layer4.2.bn1.num_batches_tracked", "layer4.2.conv2.weight", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.bn2.num_batches_tracked", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var", "layer4.2.bn3.num_batches_tracked". 

In [None]:
#Step 5: Greedy Decode / Beam Search to generate caption
def greedy_decode(model, feat, vocab, max_len=45):
    idx2word = {v: k for k, v in vocab.items()}
    generated = [vocab['<SOS>']]
    for _ in range(max_len):
        trg = torch.tensor(generated).unsqueeze(0).to(device)
        with torch.no_grad():
            out = model(feat, trg)
        next_token = out[0, -1].argmax().item()
        if next_token == vocab['<EOS>']:
            break
        generated.append(next_token)

    return ' '.join([idx2word[idx] for idx in generated[1:]])

# Generate Caption
caption = greedy_decode(model, video_feat, vocab)
print("🎬 Predicted Caption:", caption)


In [65]:
#Main tumhare training phase me auto-save logic de du
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

def train_transformer(model, train_dataset, val_dataset, vocab, device,
                      num_epochs=20, batch_size=8, learning_rate=1e-4,
                      checkpoint_dir='/content/drive/MyDrive/transformer_checkpoints',
                      save_every=5):  # ← Save after every N epochs

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=vocab['<PAD>'])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0

        print(f"\n🚀 Epoch {epoch}/{num_epochs}")

        for video_feats, captions in tqdm(train_loader, desc="Training"):
            video_feats, captions = video_feats.to(device), captions.to(device)
            optimizer.zero_grad()

            tgt_input = captions[:, :-1]
            tgt_output = captions[:, 1:]

            logits = model(video_feats, tgt_input)
            loss = criterion(logits.view(-1, logits.size(-1)), tgt_output.reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"✅ Training Loss: {avg_loss:.4f}")

        # === Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for video_feats, captions in val_loader:
                video_feats, captions = video_feats.to(device), captions.to(device)
                tgt_input = captions[:, :-1]
                tgt_output = captions[:, 1:]

                logits = model(video_feats, tgt_input)
                loss = criterion(logits.view(-1, logits.size(-1)), tgt_output.reshape(-1))
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"🧪 Validation Loss: {avg_val_loss:.4f}")

        # === Auto-save checkpoint
        if epoch % save_every == 0 or epoch == num_epochs:
            ckpt_path = os.path.join(checkpoint_dir, f"transformer_decoder_epoch_{epoch}.pt")
            torch.save(model.state_dict(), ckpt_path)
            print(f"💾 Saved checkpoint: {ckpt_path}")


In [66]:
#train
train_transformer(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    vocab=vocab,
    device=device,
    num_epochs=30,
    batch_size=8,
    learning_rate=1e-4,
    checkpoint_dir="/content/drive/MyDrive/transformer_checkpoints",
    save_every=5  # Save at epochs 5, 10, 15, ...
)



🚀 Epoch 1/30


Training: 100%|██████████| 157/157 [00:08<00:00, 17.95it/s]


✅ Training Loss: 5.4108
🧪 Validation Loss: 4.5337

🚀 Epoch 2/30


Training: 100%|██████████| 157/157 [00:07<00:00, 19.83it/s]


✅ Training Loss: 4.3264
🧪 Validation Loss: 4.3262

🚀 Epoch 3/30


Training: 100%|██████████| 157/157 [00:08<00:00, 19.59it/s]


✅ Training Loss: 4.0294
🧪 Validation Loss: 3.7730

🚀 Epoch 4/30


Training: 100%|██████████| 157/157 [00:08<00:00, 19.04it/s]


✅ Training Loss: 3.8594
🧪 Validation Loss: 3.8122

🚀 Epoch 5/30


Training: 100%|██████████| 157/157 [00:07<00:00, 20.24it/s]


✅ Training Loss: 3.7236
🧪 Validation Loss: 3.7753
💾 Saved checkpoint: /content/drive/MyDrive/transformer_checkpoints/transformer_decoder_epoch_5.pt

🚀 Epoch 6/30


Training: 100%|██████████| 157/157 [00:08<00:00, 18.67it/s]


✅ Training Loss: 3.6082
🧪 Validation Loss: 3.3932

🚀 Epoch 7/30


Training: 100%|██████████| 157/157 [00:08<00:00, 18.76it/s]


✅ Training Loss: 3.5553
🧪 Validation Loss: 3.8743

🚀 Epoch 8/30


Training: 100%|██████████| 157/157 [00:07<00:00, 20.07it/s]


✅ Training Loss: 3.5894
🧪 Validation Loss: 3.3342

🚀 Epoch 9/30


Training: 100%|██████████| 157/157 [00:08<00:00, 19.15it/s]


✅ Training Loss: 3.4466
🧪 Validation Loss: 3.4264

🚀 Epoch 10/30


Training: 100%|██████████| 157/157 [00:08<00:00, 19.20it/s]


✅ Training Loss: 3.3852
🧪 Validation Loss: 3.5293
💾 Saved checkpoint: /content/drive/MyDrive/transformer_checkpoints/transformer_decoder_epoch_10.pt

🚀 Epoch 11/30


Training: 100%|██████████| 157/157 [00:08<00:00, 19.45it/s]


✅ Training Loss: 3.2952
🧪 Validation Loss: 3.6353

🚀 Epoch 12/30


Training: 100%|██████████| 157/157 [00:08<00:00, 18.04it/s]


✅ Training Loss: 3.2118
🧪 Validation Loss: 2.9639

🚀 Epoch 13/30


Training: 100%|██████████| 157/157 [00:08<00:00, 19.16it/s]


✅ Training Loss: 3.1912
🧪 Validation Loss: 3.4093

🚀 Epoch 14/30


Training: 100%|██████████| 157/157 [00:07<00:00, 20.33it/s]


✅ Training Loss: 3.1633
🧪 Validation Loss: 3.2087

🚀 Epoch 15/30


Training: 100%|██████████| 157/157 [00:08<00:00, 19.05it/s]


✅ Training Loss: 3.1003
🧪 Validation Loss: 3.1731
💾 Saved checkpoint: /content/drive/MyDrive/transformer_checkpoints/transformer_decoder_epoch_15.pt

🚀 Epoch 16/30


Training: 100%|██████████| 157/157 [00:08<00:00, 17.74it/s]


✅ Training Loss: 3.1000
🧪 Validation Loss: 3.1552

🚀 Epoch 17/30


Training: 100%|██████████| 157/157 [00:07<00:00, 20.28it/s]


✅ Training Loss: 3.0357
🧪 Validation Loss: 3.2603

🚀 Epoch 18/30


Training: 100%|██████████| 157/157 [00:08<00:00, 19.14it/s]


✅ Training Loss: 2.9433
🧪 Validation Loss: 3.0677

🚀 Epoch 19/30


Training: 100%|██████████| 157/157 [00:08<00:00, 19.48it/s]


✅ Training Loss: 2.8947
🧪 Validation Loss: 3.2817

🚀 Epoch 20/30


Training: 100%|██████████| 157/157 [00:07<00:00, 20.41it/s]


✅ Training Loss: 2.9638
🧪 Validation Loss: 2.9872
💾 Saved checkpoint: /content/drive/MyDrive/transformer_checkpoints/transformer_decoder_epoch_20.pt

🚀 Epoch 21/30


Training: 100%|██████████| 157/157 [00:08<00:00, 18.67it/s]


✅ Training Loss: 2.8218
🧪 Validation Loss: 3.2754

🚀 Epoch 22/30


Training: 100%|██████████| 157/157 [00:08<00:00, 18.46it/s]


✅ Training Loss: 2.8400
🧪 Validation Loss: 3.2090

🚀 Epoch 23/30


Training: 100%|██████████| 157/157 [00:07<00:00, 20.24it/s]


✅ Training Loss: 2.8855
🧪 Validation Loss: 3.1425

🚀 Epoch 24/30


Training: 100%|██████████| 157/157 [00:08<00:00, 19.17it/s]


✅ Training Loss: 2.8026
🧪 Validation Loss: 3.3355

🚀 Epoch 25/30


Training: 100%|██████████| 157/157 [00:07<00:00, 19.93it/s]


✅ Training Loss: 2.7615
🧪 Validation Loss: 3.0278
💾 Saved checkpoint: /content/drive/MyDrive/transformer_checkpoints/transformer_decoder_epoch_25.pt

🚀 Epoch 26/30


Training: 100%|██████████| 157/157 [00:08<00:00, 19.12it/s]


✅ Training Loss: 2.7005
🧪 Validation Loss: 3.2970

🚀 Epoch 27/30


Training: 100%|██████████| 157/157 [00:08<00:00, 18.78it/s]


✅ Training Loss: 2.7575
🧪 Validation Loss: 2.9026

🚀 Epoch 28/30


Training: 100%|██████████| 157/157 [00:07<00:00, 19.90it/s]


✅ Training Loss: 2.7813
🧪 Validation Loss: 3.1231

🚀 Epoch 29/30


Training: 100%|██████████| 157/157 [00:07<00:00, 19.70it/s]


✅ Training Loss: 2.6372
🧪 Validation Loss: 2.7363

🚀 Epoch 30/30


Training: 100%|██████████| 157/157 [00:08<00:00, 19.12it/s]


✅ Training Loss: 2.6946
🧪 Validation Loss: 2.6889
💾 Saved checkpoint: /content/drive/MyDrive/transformer_checkpoints/transformer_decoder_epoch_30.pt


In [70]:
video_path = '/content/04.mp4'                # Local video
frame_dir = '/content/frames_extracted'       # Frames folder
feature_path = '/content/04_feat.npy'         # Local feature file
vocab_path = '/content/drive/MyDrive/msvd_split/vocab.json'
model_path = '/content/drive/MyDrive/transformer_checkpoints/transformer_decoder_epoch_30.pt'


In [71]:
import os
import shutil
import subprocess

if os.path.exists(frame_dir):
    shutil.rmtree(frame_dir)
os.makedirs(frame_dir, exist_ok=True)

# Extract at 2 fps
!ffmpeg -i "$video_path" -vf "fps=2" "$frame_dir/frame_%03d.jpg" -hide_banner -loglevel error


In [72]:
from torchvision import transforms
from PIL import Image
import torch
import glob

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

frame_paths = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')))
frames_tensor = torch.stack([transform(Image.open(fp)) for fp in frame_paths]).to(device)
print("🎞️ Loaded frames:", frames_tensor.shape)


🎞️ Loaded frames: torch.Size([60, 3, 224, 224])


In [73]:
import torchvision.models as models
import torch.nn as nn
import numpy as np

resnet = models.resnet152(pretrained=True)
resnet.fc = nn.Identity()
resnet = resnet.to(device)
resnet.eval()

with torch.no_grad():
    feats = resnet(frames_tensor)  # [T, 2048]

# Save locally
np.save(feature_path, feats.cpu().numpy())
print(f"✅ Saved features to {feature_path}")


Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:01<00:00, 172MB/s]


✅ Saved features to /content/04_feat.npy


In [74]:
feat_arr = np.load(feature_path)
video_feat = torch.tensor(feat_arr).unsqueeze(0).to(torch.float32).to(device)
print("📦 Video Feature Shape:", video_feat.shape)


📦 Video Feature Shape: torch.Size([1, 60, 2048])


In [75]:
import json

with open(vocab_path, 'r') as f:
    vocab = json.load(f)
rev_vocab = {v: k for k, v in vocab.items()}


In [76]:
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)].to(x.device)
        return self.dropout(x)

class VideoTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=3, dim_ff=2048, max_len=45, dropout=0.1, input_feat_dim=2048):
        super().__init__()
        self.input_fc = nn.Linear(input_feat_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_ff, dropout)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.max_len = max_len
        self.d_model = d_model

    def forward(self, vid_feats, tgt_input):
        vid_feats = self.input_fc(vid_feats)
        memory = self.pos_encoder(vid_feats)
        tgt_embed = self.embedding(tgt_input) * (self.d_model ** 0.5)
        tgt_embed = self.pos_encoder(tgt_embed)
        output = self.decoder(tgt_embed.transpose(0, 1), memory.transpose(0, 1))
        return self.fc_out(output.transpose(0, 1))


In [77]:
model = VideoTransformer(
    vocab_size=len(vocab),
    d_model=512,
    nhead=8,
    num_layers=3,
    max_len=45,
    input_feat_dim=2048
).to(device)

model.load_state_dict(torch.load(model_path))
model.eval()


RuntimeError: Error(s) in loading state_dict for VideoTransformer:
	Missing key(s) in state_dict: "input_fc.weight", "input_fc.bias", "fc_out.weight", "fc_out.bias". 
	Unexpected key(s) in state_dict: "pos_enc.pe", "vid_fc.weight", "vid_fc.bias", "output.weight", "output.bias". 

In [78]:
torch.save(model.state_dict(), '/content/drive/MyDrive/transformer_checkpoints/video_transformer_final.pt')


In [79]:
state_dict = torch.load(model_path)
model.load_state_dict(state_dict, strict=False)  # ⚠️ Not recommended for deployment


_IncompatibleKeys(missing_keys=['input_fc.weight', 'input_fc.bias', 'fc_out.weight', 'fc_out.bias'], unexpected_keys=['pos_enc.pe', 'vid_fc.weight', 'vid_fc.bias', 'output.weight', 'output.bias'])

In [80]:
# Save after training
torch.save(model.state_dict(), '/content/drive/MyDrive/transformer_checkpoints/video_transformer_final.pt')


In [None]:
# NEXT MODEL

