# Import necessary libraries


In [8]:
from config import TrainConfig as C
from models.abd_transformer import ABDTransformer
import torch
import h5py
from utils import dict_to_cls

# Load checkpoint and config


In [9]:
CHECKPOINT_PATH = "checkpoints/best.ckpt"

# Map every tensor to CPU so the checkpoint loads without a GPU.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")

# The stored config is converted for attribute-style access (config.feat.size, ...);
# presumably it is saved as a plain dict — confirm against the training script.
ckpt_config = checkpoint["config"]
config = dict_to_cls(ckpt_config)

In [10]:
%%capture
!pip install pandas torchvision

In [11]:
# Imported here rather than in the top import cell — presumably because loader.MSVD
# depends on the packages installed by the previous cell (pandas/torchvision); confirm.
from loader.MSVD import MSVD
# Construct the MSVD corpus from the checkpoint's config; its `vocab` is used to build the model.
corpus = MSVD(config)

# Build the model


In [12]:
vocab = corpus.vocab

# Build the model.
# Positional constructor arguments shared by both constructor variants below.
model_args = (
    vocab,
    config.feat.size,
    config.transformer.d_model,
    config.transformer.d_ff,
    config.transformer.n_heads,
    config.transformer.n_layers,
    config.transformer.dropout,
    config.feat.feature_mode,
)
try:
    # Newer configs / model versions take the extra `select_num` argument.
    model = ABDTransformer(*model_args, n_heads_big=config.transformer.n_heads_big,
                           select_num=config.transformer.select_num)
except (TypeError, AttributeError, KeyError) as err:
    # Older setups: the config has no `select_num` (AttributeError/KeyError) or the
    # constructor does not accept it (TypeError). Any other error propagates instead
    # of being silently swallowed by a bare `except:`.
    print(f"Falling back to ABDTransformer without select_num: {err}")
    model = ABDTransformer(*model_args, n_heads_big=config.transformer.n_heads_big)

model.load_state_dict(checkpoint['abd_transformer'])
model.device = "cpu"

# Move model to cpu
model = model.to("cpu")

# Load extracted features


In [13]:
# Load saved features
image_feats = torch.load('features/image_feats.pt', map_location="cpu")


def _load_motion_feats(h5_path='features/motion_feats.hdf5',
                       pt_path='features/motion_feats.pt',
                       candidate_keys=('features', 'motion_features', 'data')):
    """Load motion features from an HDF5 file, falling back to a .pt file.

    The dataset is read from the first of ``candidate_keys`` present in the file;
    if none is present, the first top-level key is used. If the HDF5 file is missing
    or cannot be read (including an empty file or a key that is a Group rather than
    a Dataset), ``pt_path`` is loaded with ``torch.load`` instead.

    Returns:
        A float32 tensor when read from HDF5; otherwise whatever ``torch.load``
        returns for ``pt_path``.
    """
    try:
        with h5py.File(h5_path, 'r') as f:
            keys = list(f.keys())
            # Print available keys to see the structure
            print("Available keys in HDF5 file:", keys)

            # Prefer common dataset key names, in order.
            key = next((k for k in candidate_keys if k in f), None)
            if key is None:
                if not keys:
                    raise KeyError(f"no entries found in {h5_path}")
                # Use the first available key
                key = keys[0]
                print(f"Using key: {key}")

            entry = f[key]
            if not isinstance(entry, h5py.Dataset):
                raise TypeError(f"HDF5 entry {key!r} is a {type(entry).__name__}, not a Dataset")
            # Read the data fully before the file is closed by the `with` block.
            return torch.tensor(entry[:], dtype=torch.float32)
    except FileNotFoundError:
        print("HDF5 file not found, falling back to .pt file")
    except Exception as e:
        print(f"Error loading HDF5 file: {e}")
        print("Falling back to .pt file")
    return torch.load(pt_path, map_location="cpu")


motion_feats = _load_motion_feats()

# Add batch dimension if needed (for shape [50, 1024] -> [1, 50, 1024]) so the motion
# features match the batched layout of the other feature tensors.
print(f"Motion features shape before unsqueeze: {motion_feats.shape}")
if motion_feats.ndim == 2:
    motion_feats = motion_feats.unsqueeze(0)
    print(f"Motion features shape after unsqueeze: {motion_feats.shape}")

# NOTE: `obect_feats` (sic) is kept as-is because the inference cell references this name.
obect_feats = torch.load('features/object_feats.pt', map_location="cpu")
rel_feats = torch.load('features/rel_feats.pt', map_location="cpu")

print("Image features shape:", image_feats.shape)
print("Motion features shape:", motion_feats.shape)
print("Object features shape:", obect_feats.shape)
print("Relation features shape:", rel_feats.shape)

Available keys in HDF5 file: ['lifting']
Using key: lifting
Motion features shape before unsqueeze: torch.Size([50, 1024])
Motion features shape after unsqueeze: torch.Size([1, 50, 1024])
Image features shape: torch.Size([1, 50, 1536])
Motion features shape: torch.Size([1, 50, 1024])
Object features shape: torch.Size([1, 50, 1028])
Relation features shape: torch.Size([1, 50, 300])


# Inference with beam search


In [14]:
%%time
model.eval()
beam_size = config.beam_size
max_len = config.loader.max_caption_len
feature_mode = config.feat.feature_mode
feats = (image_feats, motion_feats, obect_feats, rel_feats)
with torch.no_grad():
    r2l_captions, l2r_captions = model.beam_search_decode(feats, beam_size, max_len)
    # r2l_captions = [idxs_to_sentence(caption, vocab.idx2word, BOS_idx) for caption in r2l_captions]
    l2r_captions = [" ".join(caption[0].value) for caption in l2r_captions]
    r2l_captions = [" ".join(caption[0].value) for caption in r2l_captions]
    
    print(f"Left to Right Captions: {l2r_captions}")

Left to Right Captions: ['the person is doing the something']
CPU times: total: 8.28 s
Wall time: 6.29 s
