In [1]:
# Environment setup. The order of the statements in this cell matters.
import os
# Expose only GPU 0 to this process; set before `torch` is imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# Work from the project root: the relative paths used later in this notebook
# (e.g. './data/odv_vqa') are resolved against this directory.
# NOTE(review): presumably this is also what makes the project-local
# `models` / `data` imports below resolvable — confirm.
os.chdir('/data/ly/code/LinVQATools')

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
# Project-local modules (not installed packages).
from models.video_mae_vqa import VideoMAEVQAWrapper
from data.default_dataset import SingleBranchDataset



加载模型和权重 (Load the model and its weights)

In [2]:

# Build the VQA model with the hyper-parameters used for this run.
model = VideoMAEVQAWrapper(model_type='s', mask_ratio=0.75,
                           head_dropout=0.1, drop_path_rate=0.1)

# Checkpoint to evaluate (file name suggests it was selected by best SROCC).
weight_path = '/data/ly/code/LinVQATools/work_dir/video_mae_vqa/01171449 vit random_cell_mask_75 mae last 4clip/best_SROCC_epoch_555.pth'

# Deserialize onto the CPU; the model is moved to the GPU later, before inference.
weight = torch.load(weight_path, map_location=torch.device("cpu"))

# The checkpoint keeps the parameters under the 'state_dict' key.
# Printing the returned result shows whether every key matched.
info = model.load_state_dict(weight['state_dict'])
print(info)

_IncompatibleKeys(missing_keys=['model.mean', 'model.std', 'model.mask_token', 'model.pos_embed', 'model.backbone.norm.weight', 'model.backbone.norm.bias', 'model.vqa_head.norm.weight', 'model.vqa_head.norm.bias', 'model.vqa_head.fc_hid.1.weight', 'model.vqa_head.fc_hid.1.bias', 'model.vqa_head.fc_last.0.weight', 'model.vqa_head.fc_last.0.bias', 'model.encoder_to_decoder.weight'], unexpected_keys=['model.backbone.fc_norm.weight', 'model.backbone.fc_norm.bias', 'model.backbone.head.weight', 'model.backbone.head.bias'])
<All keys matched successfully>


加载验证集 (Load the validation set)

In [3]:
# Validation-split data pipeline.

# Number of DataLoader worker processes. The value was previously defined but
# never handed to the DataLoader, so loading ran in the main process only.
num_workers = 6
prefix = '4frame'

# Processing steps passed to the video loader through its `argument` option.
# NOTE(review): the exact semantics of `fragment_size`, `frame_cube` and `num`
# are defined by the project's FragmentShuffler / PostProcessSampler — confirm there.
argument = [
    dict(
        name='FragmentShuffler',
        fragment_size=32,
        frame_cube=4,
    ),
    dict(
        name='PostProcessSampler',
        frame_cube=4,
        num=4,
    ),
]

# Video-loader configuration for the 'test' phase.
val_video_loader = dict(
    name='FragmentLoader',
    prefix=prefix,
    frame_sampler=None,
    spatial_sampler=None,
    argument=argument,
    phase='test',
    use_preprocess=True,
)

# Paths are relative to the working directory selected in the setup cell.
dataset = SingleBranchDataset(
    video_loader=val_video_loader,
    anno_root='./data/odv_vqa',
    anno_reader='ODVVQAReader',
    split_file='./data/odv_vqa/tr_te_VQA_ODV.txt',
    phase='test',
    norm=True,
    clip=4,
)

val_dataloader = DataLoader(
    dataset=dataset,
    batch_size=1,
    # Keep the dataset order: the per-group SROCC computed later slices
    # consecutive samples.
    shuffle=False,
    num_workers=num_workers,
)

In [12]:
# Run the model over the validation set and collect one
# (ground truth, prediction) pair per sample.
gt = []  # ground-truth scores, as Python floats
pr = []  # predicted scores, as Python floats

model = model.cuda().eval()
with torch.no_grad():
    for item in tqdm(val_dataloader):
        y = model(
            inputs=item["inputs"].cuda(),
            gt_label=item['gt_label'].cuda(),
            mode='predict',
        )
        # Append both values only after the forward pass succeeded, so `gt` and
        # `pr` always have the same length even if the loop is interrupted.
        # float() copies the value off the GPU instead of keeping a CUDA tensor
        # alive per sample. It requires single-element values, which the later
        # `torch.tensor(list)` conversion required as well.
        gt.append(float(item['gt_label']))
        pr.append(float(y[0]))

100%|██████████| 108/108 [00:44<00:00,  2.44it/s]


In [13]:
from scipy.stats import spearmanr
import copy

# Number of consecutive samples that are scored together as one group.
# NOTE(review): presumably the 9 distorted versions of one reference video — confirm.
GROUP_SIZE = 9

# Work on copies so the raw `gt` / `pr` lists are left untouched.
all_gt = torch.tensor(copy.deepcopy(gt))
all_pr = torch.tensor(copy.deepcopy(pr))

# Spearman rank correlation over every evaluated sample.
all_srocc = spearmanr(all_gt, all_pr)[0]
print(all_srocc)

# Per-group SROCC. The number of groups is derived from the number of samples
# actually evaluated instead of being hard-coded, so a different split size or
# a partial run no longer yields empty slices. A trailing incomplete group is ignored.
num_groups = len(all_gt) // GROUP_SIZE
srocc_list = []
for i in range(num_groups):
    part_pr = all_pr[i * GROUP_SIZE:(i + 1) * GROUP_SIZE]
    part_gt = all_gt[i * GROUP_SIZE:(i + 1) * GROUP_SIZE]
    srocc = spearmanr(part_gt, part_pr)[0]
    srocc_list.append(srocc)
print(srocc_list)

# NOTE(review): the line below looks like hand-recorded "<id> <SROCC>" reference
# values for the 12 groups — confirm their origin.
# 3 0.83, 4 0.9, 12 0.83, 17 0.84, 19 0.61, 21 0.9, 23 0.73, 28 0.78, 30 0.93, 39 0.93, 40 0.91, 58 0.93

0.8921851629559766
[0.8666666666666667, 0.9833333333333333, 0.7999999999999999, 0.7999999999999999, 0.8666666666666667, 0.9166666666666666, 0.8833333333333333, 0.8499999999999999, 0.8833333333333333, 0.9166666666666666, 0.9, 0.9500000000000001]


In [6]:
# Training-split data pipeline followed by inference over that split.

# Number of DataLoader worker processes. The value was previously defined but
# never handed to the DataLoader, so loading ran in the main process only.
num_workers = 6
prefix = '4frame'

# Processing steps passed to the video loader through its `argument` option.
# NOTE(review): the exact semantics of `fragment_size`, `frame_cube` and `num`
# are defined by the project's FragmentShuffler / PostProcessSampler — confirm there.
argument = [
    dict(
        name='FragmentShuffler',
        fragment_size=32,
        frame_cube=4,
    ),
    dict(
        name='PostProcessSampler',
        frame_cube=4,
        num=4,
    ),
]

# Video-loader configuration for the 'train' phase.
train_video_loader = dict(
    name='FragmentLoader',
    prefix=prefix,
    frame_sampler=None,
    spatial_sampler=None,
    argument=argument,
    phase='train',
    use_preprocess=True,
)

# Paths are relative to the working directory selected in the setup cell.
# Note: this rebinds the notebook-level name `dataset`.
dataset = SingleBranchDataset(
    video_loader=train_video_loader,
    anno_root='./data/odv_vqa',
    anno_reader='ODVVQAReader',
    split_file='./data/odv_vqa/tr_te_VQA_ODV.txt',
    phase='train',
    norm=True,
    clip=4,
)

train_dataloader = DataLoader(
    dataset=dataset,
    batch_size=1,
    # Keep the dataset order: the per-group SROCC computed later slices
    # consecutive samples.
    shuffle=False,
    num_workers=num_workers,
)

gt = []  # ground-truth scores, as Python floats
pr = []  # predicted scores, as Python floats

model = model.cuda().eval()
with torch.no_grad():
    for item in tqdm(train_dataloader):
        y = model(
            inputs=item["inputs"].cuda(),
            gt_label=item['gt_label'].cuda(),
            mode='predict',
        )
        # Append both values only after the forward pass succeeded, so `gt` and
        # `pr` always have the same length even if the loop is interrupted.
        # float() copies the value off the GPU instead of keeping a CUDA tensor
        # alive per sample. It requires single-element values, which the later
        # `torch.tensor(list)` conversion required as well.
        gt.append(float(item['gt_label']))
        pr.append(float(y[0]))

 13%|█▎        | 56/432 [01:00<06:49,  1.09s/it]


KeyboardInterrupt: 

In [None]:
from scipy.stats import spearmanr
import copy

# Number of consecutive samples that are scored together as one group.
# NOTE(review): presumably the 9 distorted versions of one reference video — confirm.
GROUP_SIZE = 9

# Work on copies so the raw `gt` / `pr` lists are left untouched.
all_gt = torch.tensor(copy.deepcopy(gt))
all_pr = torch.tensor(copy.deepcopy(pr))

# Spearman rank correlation over every evaluated sample.
all_srocc = spearmanr(all_gt, all_pr)[0]
print(all_srocc)

# Per-group SROCC. The number of groups is derived from the number of samples
# actually evaluated instead of being hard-coded, so a partial inference run
# no longer yields empty slices. A trailing incomplete group is ignored.
num_groups = len(all_gt) // GROUP_SIZE
srocc_list = []
for i in range(num_groups):
    part_pr = all_pr[i * GROUP_SIZE:(i + 1) * GROUP_SIZE]
    part_gt = all_gt[i * GROUP_SIZE:(i + 1) * GROUP_SIZE]
    srocc = spearmanr(part_gt, part_pr)[0]
    srocc_list.append(srocc)
print(srocc_list)