In [None]:
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
import os
import h5py
# sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
from utils.utils import *
    
class NEXTQADataset(Dataset):
    """Multiple-choice NExT-QA annotations paired with per-video frame folders.

    Every annotation row is flattened into parallel per-field lists when the
    object is constructed; ``__getitem__`` then returns one row as a dict of
    strings together with the frame filenames selected for its video.

    Relies on ``load_csv``, ``load_json``, ``image_transform`` and
    ``get_frames``, which are expected to come from the file-level
    ``from utils.utils import *``.
    """

    # Letter shown for each answer index: a0..a4 -> A..E.
    OPTION_LETTERS = ("A", "B", "C", "D", "E")

    def __init__(
        self,
        anno_path = '../nextqa/annotations_mc/train.csv',
        mapper_path = '../nextqa/map_vid_vidorID.json',
        video_path = "../nextqa/videos",
        frame_path = "../nextqa/frames_32",
        feature_path = "../nextqa/vision_features/feats_wo_norm_32.h5",
        frame_count = 32
    ):
        """Load the annotations and precompute the per-sample fields.

        Args:
            anno_path: Annotation CSV; each row must provide ``video``,
                ``qid``, ``type``, ``question``, ``a0``..``a4`` and ``answer``.
            mapper_path: JSON mapping ``str(video)`` to the stem of the
                matching ``.mp4`` file under ``video_path``.
            video_path: Directory that holds the ``.mp4`` files.
            frame_path: Directory with one sub-folder of frames per video id.
            feature_path: HDF5 file that is opened read-only.
            frame_count: Forwarded to ``get_frames`` in ``__getitem__``.
        """
        # assumes load_csv yields dict-like rows — TODO confirm against utils.utils
        self.data = load_csv(anno_path)
        self.mapper = load_json(mapper_path)
        self.video_path = video_path
        self.frame_path = frame_path
        self.frame_count = frame_count
        self.image_processor = image_transform(image_size=224)
        # NOTE: this handle stays open for the lifetime of the object and is
        # not read by __getitem__ at present (feature loading is disabled there).
        self.image_features = h5py.File(feature_path, "r")

        self.video_ids = []
        self.videos = []
        self.frames = []
        self.questions = []
        self.answers_option = []
        self.answers_text = []
        self.answers_ids = []
        self.types = []
        self.qids = []
        self.options_a0 = []
        self.options_a1 = []
        self.options_a2 = []
        self.options_a3 = []
        self.options_a4 = []

        for data in self.data:
            # `answer` is used as a sequence index below, so it must be an
            # integer in 0..4.
            answer = data['answer']

            self.video_ids.append(data['video'])
            self.qids.append(data['qid'])
            self.types.append(data['type'])
            # A question mark is appended to every question text.
            self.questions.append(data['question']+"?")
            self.options_a0.append(data['a0'])
            self.options_a1.append(data['a1'])
            self.options_a2.append(data['a2'])
            self.options_a3.append(data['a3'])
            self.options_a4.append(data['a4'])

            self.answers_ids.append(answer)
            self.answers_text.append(data[f"a{str(answer)}"])
            self.answers_option.append(self.OPTION_LETTERS[answer])
            self.videos.append(self.video_path + f"/{self.mapper[str(data['video'])]}.mp4")
            self.frames.append(self.frame_path + f"/{str(data['video'])}")

    def __len__(self):
        """Return the number of QA samples (one per annotation row)."""
        return len(self.video_ids)

    def __getitem__(self, index):
        """Return the annotation fields and frame filenames of one sample.

        Args:
            index: Position of the sample in the annotation order.

        Returns:
            A dict with the string fields ``video_ids``, ``qids``, ``types``,
            ``questions``, ``options_a0``..``options_a4``, ``answers_text``
            and ``answers`` (the option letter), the raw ``answers_id``, and
            ``frame_files`` (frame filenames, not pixel data or features).
        """
        video_id = str(self.video_ids[index])
        qid = str(self.qids[index])
        question_type = str(self.types[index])
        question = str(self.questions[index])
        option_a0 = str(self.options_a0[index])
        option_a1 = str(self.options_a1[index])
        option_a2 = str(self.options_a2[index])
        option_a3 = str(self.options_a3[index])
        option_a4 = str(self.options_a4[index])
        answer_id = self.answers_ids[index]
        answer_text = str(self.answers_text[index])
        answer_option = str(self.answers_option[index])

        # Order frames numerically by the integer found between the first "_"
        # and the following "." of each filename, so "x_10" sorts after "x_2".
        frame_files = sorted(
            os.listdir(str(self.frames[index])),
            key=lambda name: int(name.split('_')[1].split('.')[0]),
        )
        # presumably reduces the list to `frame_count` entries — verify in utils.utils
        frame_files = get_frames(frame_files, self.frame_count)

        # Pre-extracted features are intentionally not loaded here. The
        # disabled variant read self.image_features["{video_id}_{frame stem}"]
        # for each frame and stacked the results under "frame_features".
        return {
            "video_ids": video_id,
            "qids": qid,
            "types": question_type,

            "frame_files": frame_files,

            "questions": question,
            "options_a0": option_a0,
            "options_a1": option_a1,
            "options_a2": option_a2,
            "options_a3": option_a3,
            "options_a4": option_a4,
            "answers_id": answer_id,
            "answers_text": answer_text,
            "answers": answer_option,
        }
    


In [None]:
"""각 QA pair 를 plot 으로 변환해서 ./temp/ 에 저장 (이미지 시퀀스 + QA 텍스트박스)"""

import glob
import os
import matplotlib.pyplot as plt
from PIL import Image
import tqdm.auto as tqdm
from torchvision import transforms


# Create a temporary directory to save plots
os.makedirs('./temp', exist_ok=True)

# NOTE: despite its name, `train_dataset` currently wraps the *test* annotations.
# train_dataset = NEXTQADataset(anno_path='../nextqa/annotations_mc/train.csv', frame_count=32)
train_dataset = NEXTQADataset(anno_path='../nextqa/annotations_mc/test.csv', frame_count=32)

# Deliberately NOT named `image_transform`: NEXTQADataset.__init__ calls the
# module-level `image_transform(image_size=224)` helper, so rebinding that name
# to a Compose object makes every later NEXTQADataset(...) call fail with a
# TypeError.
# NOTE(review): Resize / CenterCrop / InterpolationMode are not imported in this
# cell; presumably they arrive through `from utils.utils import *` — confirm.
plot_transform = transforms.Compose([
        Resize(224, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(224),
    ])


def export_to_plot(sample, frame_root='../nextqa/frames_32', out_dir='./temp'):
    """Render one QA sample as a 4x8 frame grid with a QA text box and save it.

    Args:
        sample: Dict as returned by ``NEXTQADataset.__getitem__``; the keys
            ``video_ids``, ``qids``, ``types``, ``questions``,
            ``options_a0``..``options_a4``, ``answers`` and ``frame_files``
            are read.
        frame_root: Directory containing one sub-folder of frames per video id.
        out_dir: Existing directory the PNG is written to.

    Returns:
        Path of the written PNG, ``{out_dir}/{video_id}_{qid}.png``.
    """
    video_id = sample['video_ids']
    qa_text = (f"video_ids: {video_id}  qids: {sample['qids']}  types: {sample['types']}\n"
            f"Question: {sample['questions']}\n"
            f"Options: \n"
            f"  A: {sample['options_a0']}\n"
            f"  B: {sample['options_a1']}\n"
            f"  C: {sample['options_a2']}\n"
            f"  D: {sample['options_a3']}\n"
            f"  E: {sample['options_a4']}\n"
            f"Answer: {sample['answers']}")
    out_path = f'{out_dir}/{video_id}_{sample["qids"]}.png'

    # Extra figure height leaves room for the text box above the grid.
    fig, axes = plt.subplots(4, 8, figsize=(20, 12))
    try:
        axes = axes.flatten()
        # Hide every axis up front so grid cells without a frame stay blank.
        for ax in axes:
            ax.axis('off')

        # zip() stops at the shorter input: at most 32 frames are drawn.
        for ax, frame in zip(axes, sample['frame_files']):
            # The context manager releases the file handle right after the
            # pixels have been handed to matplotlib.
            with Image.open(f"{frame_root}/{video_id}/{frame}") as img:
                ax.imshow(plot_transform(img))
            # Add the filename as a title
            ax.set_title(frame, fontsize=8)

        # Add a textbox with info at the top left corner
        fig.text(0.02, 0.98, qa_text, fontsize=12, verticalalignment='top', bbox=dict(facecolor='white', alpha=0.5))

        # Keep the top 5% of the figure free for the text box.
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        fig.savefig(out_path)
    finally:
        # Always release the figure, even when loading or saving fails.
        plt.close(fig)
    return out_path


# Parallel export
from joblib import Parallel, delayed

num_workers = 16
# One saved-plot path per sample, in dataset order.
temp_paths = Parallel(num_workers)(delayed(export_to_plot)(x) for x in tqdm.tqdm(train_dataset, total=len(train_dataset)))

print("Plots saved to ./temp/")

In [None]:
"""fiftyone 시각화"""


import fiftyone as fo
import numpy as np
import pandas as pd
import tqdm.auto as tqdm
import os


# Create the FiftyOne dataset (persistent; an existing "nextqa" dataset is overwritten)
dataset = fo.Dataset("nextqa", persistent=True, overwrite=True)
# dataset = fo.load_dataset("nextqa")

# train_dataset = NEXTQADataset(anno_path='../nextqa/annotations_mc/train.csv', frame_count=32)
test_dataset = NEXTQADataset(anno_path='../nextqa/annotations_mc/test.csv', frame_count=32)
# NOTE(review): torch.load unpickles the file — only load acc_records.pth from a trusted source.
acc_records = torch.load('acc_records.pth', map_location="cpu")
# Index by (video_id, qid) so each QA pair can be looked up directly.
acc_records = pd.DataFrame(acc_records).set_index(['video_id', 'qid'])


sample_list = []
for x in tqdm.tqdm(test_dataset):
    video_id = x['video_ids']
    types = x['types']
    qids = x['qids']
    # Plot written earlier by the export cell for this QA pair.
    file = f'./temp/{video_id}_{qids}.png'
    # Look the record up once instead of once per field.
    record = acc_records.loc[video_id, qids]
    label = record['label']
    pred = record['pred']
    confidence = record['sequences_scores'].item()

    sample = fo.Sample(filepath=file)
    sample['label'] = fo.Classification(label=label)
    sample['pred'] = fo.Classification(label=pred, confidence=confidence)
    # These samples come from test.csv, so tag them as the test split.
    sample["split"] = fo.Classification(label="test")
    sample['video_ids'] = int(video_id)
    sample['qids'] = int(qids)
    sample['types'] = fo.Classification(label=types)
    sample_list.append(sample)

_ = dataset.add_samples(sample_list)



In [None]:

"""FiftyOne 앱 실행"""

import fiftyone as fo
# Re-open the "nextqa" dataset by name.
dataset = fo.load_dataset("nextqa")

# Forward port 5152 (e.g. with VS Code port forwarding), then open http://localhost:5152/ in a web browser
session = fo.launch_app(dataset, auto=False, port=5152)


In [None]:
# Restrict the App to samples whose `label` and `pred` classification labels differ.
F = fo.ViewField
session.view = dataset.match(F("label.label") != F("pred.label"))

In [None]:
# Close the FiftyOne App session.
session.close()