In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import pandas as pd
import numpy as np
from easydict import EasyDict
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm as tqdm
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

class MyMSRVTT_DataLoader(Dataset):
    """Map-style dataset over an MSRVTT caption CSV.

    Each item pairs a caption with the path of the ``.mp4`` file built from
    the row's video id; the video file itself is not opened here.
    """

    def __init__(self, csv_path, features_path):
        """Read the caption table and remember the video directory.

        Args:
            csv_path: CSV file providing ``video_id`` and ``sentence`` columns.
            features_path: Directory under which ``<video_id>.mp4`` paths
                are built.
        """
        self.features_path = features_path
        self.data = pd.read_csv(csv_path)

    def __len__(self):
        """Return the number of rows in the caption table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sentence, video_path)`` for row ``idx``."""
        clip_name = "{}.mp4".format(self.data['video_id'].values[idx])
        caption = self.data['sentence'].values[idx]
        return caption, os.path.join(self.features_path, clip_name)
    

def get_args_msrvtt():
    """Return the fixed MSRVTT evaluation settings as an ``EasyDict``.

    Keys: ``val_csv``, ``features_path``, ``batch_size_val``,
    ``num_thread_reader`` and ``cache_dir``. The values are hard-coded
    paths and loader settings; edit them here to point at another setup.
    """
    return EasyDict({
        "val_csv": '/raid/1moritz/datasets//MSRVTT/MSRVTT_JSFUSION_test.csv',
        "features_path": '/raid/1moritz/datasets//MSRVTT/MSRVTT_Videos',
        "batch_size_val": 8,
        "num_thread_reader": 1,
        "cache_dir": '/raid/1moritz/models/languagebind/downloaded_weights',
    })

def run_msrvtt_eval(model: imagebind_model.ImageBindModel, dataloader: DataLoader, device: torch.device):
    """Embed every caption/video batch yielded by ``dataloader``.

    This function only produces embeddings; building the similarity matrix
    and computing retrieval metrics is left to the caller.

    Args:
        model: ImageBind model; it is called under ``torch.no_grad()``.
        dataloader: Yields ``(sentences, video_paths)`` batches.
        device: Device handed to the ImageBind text/video input transforms.

    Returns:
        ``(batch_sentences_embeddings, batch_videos_embeddings)``: two lists
        holding one text embedding entry and one vision embedding entry per
        batch, in iteration order.
    """
    batch_sentences_embeddings, batch_videos_embeddings = [], []

    for sentences, video_paths in tqdm(dataloader, total=len(dataloader)):
        # Normalise each field to a plain list before handing it to the
        # ImageBind loaders.
        sentences = list(sentences)
        video_paths = list(video_paths)

        # Load and preprocess both modalities for this batch.
        inputs = {
            ModalityType.TEXT: data.load_and_transform_text(sentences, device),
            ModalityType.VISION: data.load_and_transform_video_data(video_paths, device),
        }

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            embeddings = model(inputs)

        batch_sentences_embeddings.append(embeddings[ModalityType.TEXT])
        batch_videos_embeddings.append(embeddings[ModalityType.VISION])

    return batch_sentences_embeddings, batch_videos_embeddings

def create_sim_matrix(batch_sentences_embeddings, batch_videos_embeddings):
    """Assemble the full text-by-video similarity matrix on the CPU.

    Every text batch is multiplied with every (transposed) video batch; each
    resulting tile is moved to the CPU as a NumPy array. Tiles are joined
    along the last axis into row blocks, and row blocks are stacked along
    axis 0.

        Returns:
            sim_matrix (Text X Video)
    """
    row_blocks = []
    for text_batch in batch_sentences_embeddings:
        tiles = [
            (text_batch @ video_batch.T).cpu().detach().numpy()
            for video_batch in batch_videos_embeddings
        ]
        row_blocks.append(np.concatenate(tiles, axis=-1))
    return np.concatenate(row_blocks, axis=0)

def compute_metrics(x):
    """Compute retrieval metrics from a query-by-candidate similarity matrix.

    Row ``i`` holds the scores of query ``i`` against all candidates, and the
    ground-truth candidate of query ``i`` is the diagonal entry ``x[i, i]``.
    Higher scores rank first.

    Args:
        x: 2-D array of similarity scores.

    Returns:
        dict with ``R1``, ``R5``, ``R10`` (percentage of queries whose
        ground truth is ranked in the top 1/5/10), ``MR`` and ``MedianR``
        (median 1-based rank) and ``MeanR`` (mean 1-based rank).
    """
    x = np.asarray(x)
    # Ground-truth score of every query, as a column for row-wise broadcasting.
    d = np.diag(x)[:, np.newaxis]
    # 0-based rank = number of candidates scoring strictly higher than the
    # ground truth. Counting (rather than searching the sorted row for the
    # diagonal value) yields exactly one rank per query even when other
    # candidates tie with the ground truth, so tied rows are not counted
    # several times in the recalls, median and mean.
    ind = np.sum(x > d, axis=1)

    metrics = {}
    metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind)
    metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind)
    metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind)
    metrics['MR'] = np.median(ind) + 1
    metrics["MedianR"] = metrics['MR']
    metrics["MeanR"] = np.mean(ind) + 1
    return metrics

def main():
    """Run the embedding pass of the MSRVTT evaluation.

    Instantiates ``imagebind_huge`` with ``pretrained=True`` in eval mode,
    builds an unshuffled loader over the CSV named by ``get_args_msrvtt()``
    and returns whatever ``run_msrvtt_eval`` returns.
    """
    cuda_ok = torch.cuda.is_available()
    assert cuda_ok
    # "cuda:0" is the first *visible* device.
    device = "cuda:0" if cuda_ok else "cpu"

    # Model in inference mode on the chosen device.
    model = imagebind_model.imagebind_huge(pretrained=True)
    model.eval()
    model.to(device)

    cfg = get_args_msrvtt()
    dataset = MyMSRVTT_DataLoader(csv_path=cfg.val_csv, features_path=cfg.features_path)
    # Fixed order and no dropped tail batch.
    loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size_val,
        num_workers=cfg.num_thread_reader,
        shuffle=False,
        drop_last=False,
    )

    return run_msrvtt_eval(model, loader, device)


# Bind the per-batch embeddings at cell (module) scope so that a later cell
# can consume them by name.
if __name__ == '__main__':
    batch_sentences_embeddings, batch_videos_embeddings = main()




  0%|          | 0/125 [00:00<?, ?it/s]

In [2]:
def calculate_rankings(batch_sentences_embeddings, batch_videos_embeddings):
    """Print text-to-video and video-to-text retrieval metrics.

    Builds the similarity matrix from the per-batch embeddings, scores it as
    is (text-to-video) and transposed (video-to-text), and prints the matrix
    size followed by R@1/R@5/R@10, median rank and mean rank for both
    directions. Returns nothing.
    """
    similarity = create_sim_matrix(batch_sentences_embeddings, batch_videos_embeddings)

    # Report the matrix size, then score both retrieval directions.
    print("MSRVTT sim matrix size: {}, {}".format(similarity.shape[0], similarity.shape[1]))
    text_to_video = compute_metrics(similarity)
    video_to_text = compute_metrics(similarity.T)
    print('\t Length-T: {}, Length-V:{}'.format(len(similarity), len(similarity[0])))

    # Metric keys in the order the report lines expect them.
    reported = ('R1', 'R5', 'R10', 'MR', 'MeanR')

    print("MSRVTT Text-to-Video:")
    print('\t>>>  R@1: {:.1f} - R@5: {:.1f} - R@10: {:.1f} - Median R: {:.1f} - Mean R: {:.1f}'.
          format(*(text_to_video[key] for key in reported)))
    print("MSRVTT Video-to-Text:")
    print('\t>>>  V2T$R@1: {:.1f} - V2T$R@5: {:.1f} - V2T$R@10: {:.1f} - V2T$Median R: {:.1f} - V2T$Mean R: {:.1f}'.
          format(*(video_to_text[key] for key in reported)))

# Relies on `batch_sentences_embeddings` / `batch_videos_embeddings` already
# being bound in the kernel namespace; they are not created in this cell.
if __name__ == '__main__':
    calculate_rankings(batch_sentences_embeddings, batch_videos_embeddings)

MSRVTT sim matrix size: 1000, 1000
	 Length-T: 1000, Length-V:1000
MSRVTT Text-to-Video:
	>>>  R@1: 36.4 - R@5: 58.9 - R@10: 69.7 - Median R: 3.0 - Mean R: 29.0
MSRVTT Video-to-Text:
	>>>  V2T$R@1: 29.1 - V2T$R@5: 53.2 - V2T$R@10: 63.8 - V2T$Median R: 5.0 - V2T$Mean R: 34.1


MSRVTT sim matrix size: 1000, 1000\
	Length-T: 1000, Length-V:1000\
MSRVTT Text-to-Video:\
	>>>  R@1: 36.4 - R@5: 58.9 - R@10: 69.7 - Median R: 3.0 - Mean R: 29.0\
MSRVTT Video-to-Text:\
	>>>  V2T$R@1: 29.1 - V2T$R@5: 53.2 - V2T$R@10: 63.8 - V2T$Median R: 5.0 - V2T$Mean R: 34.1\
	