In [1]:
import sys
sys.path.append("../")

In [2]:
import yaml
import torch
import pickle
from torch.utils.data import DataLoader
from visdial.data.dataset import VisDialDataset
from visdial.metrics import SparseGTMetrics, NDCG, scores_to_ranks

# For Validation

In [5]:
import json
config_path = '/home/quang/checkpoints/abci/s41/config.json'
split = 'val'

with open(config_path) as file:
    config = json.load(file)


eval_dataset = VisDialDataset(config, split='val')

eval_dataloader = DataLoader(eval_dataset,
                            batch_size=1)

sparse_metrics = SparseGTMetrics()
ndcg = NDCG()

round_ids = []
im_ids = []
ans_inds = []
gt_relevances = []

for batch in eval_dataloader:
    round_ids.append(batch['round_id'])
    im_ids.append(batch['img_ids'])
    if split == 'val':
        ans_inds.append(batch['ans_ind'])
        gt_relevances.append(batch["gt_relevance"])

round_ids = torch.stack(round_ids, dim=0).view(-1)
im_ids = torch.stack(im_ids, dim=0).view(-1)
ans_inds = torch.stack(ans_inds, dim=0).view(-1, 10)
gt_relevances = torch.stack(gt_relevances, dim=0).view(-1, 100)


[val2018] Tokenizing questions...
[val2018] Tokenizing answers...
[val2018] Tokenizing captions...
genome_path None


# Compute the Ensemble for Test

In [45]:
split = 'test'
model_indices = [13, 15, 19, 21, 22, 23, 24,25,26]
model_indices = [str(idx) for idx in model_indices]
pickle_paths = ['/home/quang/checkpoints/s{0}/ranks/test/no_ft_ckpt_4/disc.pkl'.format(idx) for idx in model_indices]
rank_output_path = '/home/quang/checkpoints/ranks/test/no_ft_ckpt_4/ensemble_{}.json'.format("_".join(model_indices))

all_outputs, all_img_ids, all_round_ids, all_ans_ids, all_gt_relevance = [], [], [], [], []

for pickle_path in pickle_paths:
    with open(pickle_path, 'rb') as f:
        x = pickle.load(f)
        outputs, img_ids, round_ids = x
        img_ids = torch.stack(img_ids, dim=0).view(-1)
        outputs = torch.stack(outputs, dim=0).view(-1, 1, 100)
        round_ids = torch.stack(round_ids, dim=0).view(-1)
        all_outputs.append(outputs)
        all_img_ids.append(img_ids)
        all_round_ids.append(round_ids)

avg_output = torch.zeros_like(all_outputs[0])
# print(avg_output.shape)

for out in all_outputs:
    avg_output += out 
avg_output /= float(len(all_outputs))

ranks = scores_to_ranks(avg_output)

num_rounds = all_round_ids[0]
ranks_json = []

for i in range(len(img_ids)):
    if split == 'test':
        ranks_json.append(
                {
                    "image_id": img_ids[i].item(),
                    "round_id": int(num_rounds[i].item()),
                    "ranks"   : [rank.item() for rank in ranks[i][0]],
                    }
                )

os.makedirs(os.path.dirname(rank_output_path), exist_ok=True)
json.dump(ranks_json, open(rank_output_path, "w"))

print(rank_output_path)    

/home/quang/checkpoints/s11/finetune/lr_5e-05/CosineLR/ranks/val/ckpt_4/checkpoint_11.pth_disc.pkl
/home/quang/checkpoints/s13/finetune/lr_5e-05/CosineLR/ranks/val/ckpt_4/checkpoint_11.pth_disc.pkl
/home/quang/checkpoints/s19/finetune/lr_5e-05/CosineLR/ranks/val/ckpt_4/checkpoint_11.pth_disc.pkl
/home/quang/checkpoints/s21/finetune/lr_5e-05/CosineLR/ranks/val/ckpt_4/checkpoint_11.pth_disc.pkl
/home/quang/checkpoints/s22/finetune/lr_5e-05/CosineLR/ranks/val/ckpt_4/checkpoint_11.pth_disc.pkl
/home/quang/checkpoints/s23/finetune/lr_5e-05/CosineLR/ranks/val/ckpt_4/checkpoint_11.pth_disc.pkl
/home/quang/checkpoints/s24/finetune/lr_5e-05/CosineLR/ranks/val/ckpt_4/checkpoint_11.pth_disc.pkl
/home/quang/checkpoints/s26/finetune/lr_5e-05/CosineLR/ranks/val/ckpt_4/checkpoint_11.pth_disc.pkl
/home/quang/checkpoints/ranks/val/no_ft_ckpt_4/ensemble_11_13_19_21_22_23_24_26.json


In [43]:
ls /home/quang/checkpoints/s11/finetune/lr_5e-05/CosineLR/ranks/val/ft_ckpt_4/

ls: cannot access '/home/quang/checkpoints/s11/finetune/lr_5e-05/CosineLR/ranks/val/ft_ckpt_4/': No such file or directory


# Compute the Ensemble Val

In [59]:
split = 'val'
model_indices = [11, 13, 19, 21, 22, 23, 24, 26]
model_indices = [str(idx) for idx in model_indices]
pickle_paths = ['/home/quang/checkpoints/s{0}/finetune/lr_5e-05/CosineLR/ranks/{1}/ckpt_3/checkpoint_11.pth_disc.pkl'.format(idx, split) for idx in model_indices]
rank_output_path = '/home/quang/checkpoints/ranks/{}/no_ft_ckpt_4/ensemble_{}.json'.format(split, "_".join(model_indices))
pickle_paths = ['/home/quang/checkpoints/s25/ranks/val/no_ft_ckpt_4/disc.pkl']
    
all_outputs, all_img_ids, all_round_ids, all_ans_ids, all_gt_relevance = [], [], [], [], []

for pickle_path in pickle_paths:
    print(pickle_path)
    with open(pickle_path, 'rb') as f:
        x = pickle.load(f)
        outputs, img_ids, _ = x
        img_ids = torch.stack(img_ids, dim=0).view(-1)
        outputs = torch.stack(outputs, dim=0).view(-1, 10, 100)
        all_outputs.append(outputs)
        all_img_ids.append(img_ids)
        all_round_ids.append(round_ids)


for im_id, img_id in zip(im_ids, img_ids):
    if im_id != img_id:
        print("False")

avg_output = torch.zeros_like(all_outputs[0])
# print(avg_output.shape)

for out in all_outputs:
    avg_output += out 
avg_output /= float(len(all_outputs))


print("compute scores...")
sparse_metrics.observe(avg_output, ans_inds)

rel_output = avg_output[torch.arange(avg_output.size(0)), round_ids - 1, :]
ndcg.observe(rel_output, gt_relevances)

all_metrics = {}
all_metrics.update(sparse_metrics.retrieve(reset=True))
all_metrics.update(ndcg.retrieve(reset=True))
val_keys = ['ndcg', 'mrr', 'r@1', 'r@5', 'r@10', 'mean']
for key in val_keys:
    print(key, end=',')
print()
for key in val_keys:
    print(all_metrics[key], end=',')

    

/home/quang/checkpoints/s25/ranks/val/no_ft_ckpt_4/disc.pkl
compute scores...
ndcg,mrr,r@1,r@5,r@10,mean,
0.5952447056770325,0.6482028365135193,0.5139050483703613,0.8146317601203918,0.9053294658660889,4.00227689743042,

In [6]:
split = 'val'
sparse_metrics = SparseGTMetrics()
ndcg = NDCG()


pickle_paths = [
    '/home/quang/checkpoints/abci/s41/finetune/lr_1e-05/CosineLR/ranks/val/ckpt_0/checkpoint_13.pth_disc.pkl',
    '/home/quang/checkpoints/abci/s42/finetune/lr_1e-05/CosineLR/ranks/val/ckpt_0/checkpoint_14.pth_disc.pkl',
    '/home/quang/checkpoints/abci/s44/finetune/lr_1e-05/CosineLR/ranks/val/ckpt_0/checkpoint_14.pth_disc.pkl',
    '/home/quang/checkpoints/abci/s45/finetune/lr_1e-05/CosineLR/ranks/val/ckpt_0/checkpoint_14.pth_disc.pkl',
    '/home/quang/checkpoints/abci/n41/finetune/lr_1e-05/CosineLR/ranks/val/ckpt_0/checkpoint_14.pth_disc.pkl',
    '/home/quang/checkpoints/abci/n42/finetune/lr_1e-05/CosineLR/ranks/val/ckpt_0/checkpoint_14.pth_disc.pkl',
]
all_outputs, all_img_ids, all_round_ids, all_ans_ids, all_gt_relevance = [], [], [], [], []

for pickle_path in pickle_paths:
    print(pickle_path)
    with open(pickle_path, 'rb') as f:
        x = pickle.load(f)
        outputs, img_ids, round_ids = x
        img_ids = torch.stack(img_ids, dim=0).view(-1)
        outputs = torch.stack(outputs, dim=0).view(-1, 10, 100)
        round_ids = torch.stack(round_ids, dim=0).view(-1)
        all_outputs.append(outputs)
        all_img_ids.append(img_ids)
        all_round_ids.append(round_ids)


avg_output = torch.zeros_like(all_outputs[0])

for out in all_outputs:
    avg_output += out 
avg_output /= float(len(all_outputs))


print("compute scores...")
sparse_metrics.observe(avg_output, ans_inds)

rel_output = avg_output[torch.arange(avg_output.size(0)), round_ids - 1, :]
ndcg.observe(rel_output, gt_relevances)

all_metrics = {}
all_metrics.update(sparse_metrics.retrieve(reset=True))
all_metrics.update(ndcg.retrieve(reset=True))
val_keys = ['ndcg', 'mrr', 'r@1', 'r@5', 'r@10', 'mean']
for key in val_keys:
    print(key, end=',')
print()
for key in val_keys:
    print(all_metrics[key], end=',')

/home/quang/checkpoints/abci/s41/finetune/lr_1e-05/CosineLR/ranks/val/ckpt_0/checkpoint_13.pth_disc.pkl
/home/quang/checkpoints/abci/s42/finetune/lr_1e-05/CosineLR/ranks/val/ckpt_0/checkpoint_14.pth_disc.pkl
/home/quang/checkpoints/abci/s44/finetune/lr_1e-05/CosineLR/ranks/val/ckpt_0/checkpoint_14.pth_disc.pkl
/home/quang/checkpoints/abci/s45/finetune/lr_1e-05/CosineLR/ranks/val/ckpt_0/checkpoint_14.pth_disc.pkl
/home/quang/checkpoints/abci/n41/finetune/lr_1e-05/CosineLR/ranks/val/ckpt_0/checkpoint_14.pth_disc.pkl
/home/quang/checkpoints/abci/n42/finetune/lr_1e-05/CosineLR/ranks/val/ckpt_0/checkpoint_14.pth_disc.pkl
compute scores...
ndcg,mrr,r@1,r@5,r@10,mean,
0.6778573989868164,0.6218511462211609,0.4948643445968628,0.7744670510292053,0.8721414804458618,4.906104564666748,

In [None]:
split = 'val'
model_indices = [11, 13, 19, 21, 22, 23, 24, 26]
model_indices = [str(idx) for idx in model_indices]
pickle_paths = ['/home/quang/checkpoints/s{0}/finetune/lr_5e-05/CosineLR/ranks/{1}/ckpt_3/checkpoint_11.pth_disc.pkl'.format(idx, split) for idx in model_indices]
rank_output_path = '/home/quang/checkpoints/ranks/{}/no_ft_ckpt_4/ensemble_{}.json'.format(split, "_".join(model_indices))
pickle_paths = ['/home/quang/checkpoints/s25/ranks/val/no_ft_ckpt_4/disc.pkl']
    
all_outputs, all_img_ids, all_round_ids, all_ans_ids, all_gt_relevance = [], [], [], [], []

for pickle_path in pickle_paths:
    print(pickle_path)
    with open(pickle_path, 'rb') as f:
        x = pickle.load(f)
        outputs, img_ids, _ = x
        img_ids = torch.stack(img_ids, dim=0).view(-1)
        outputs = torch.stack(outputs, dim=0).view(-1, 10, 100)
        all_outputs.append(outputs)
        all_img_ids.append(img_ids)
        all_round_ids.append(round_ids)


for im_id, img_id in zip(im_ids, img_ids):
    if im_id != img_id:
        print("False")

avg_output = torch.zeros_like(all_outputs[0])
# print(avg_output.shape)

for out in all_outputs:
    avg_output += out 
avg_output /= float(len(all_outputs))


print("compute scores...")
sparse_metrics.observe(avg_output, ans_inds)

rel_output = avg_output[torch.arange(avg_output.size(0)), round_ids - 1, :]
ndcg.observe(rel_output, gt_relevances)

all_metrics = {}
all_metrics.update(sparse_metrics.retrieve(reset=True))
all_metrics.update(ndcg.retrieve(reset=True))
val_keys = ['ndcg', 'mrr', 'r@1', 'r@5', 'r@10', 'mean']
for key in val_keys:
    print(key, end=',')
print()
for key in val_keys:
    print(all_metrics[key], end=',')

In [26]:
print(rank_output_path)

/home/quang/checkpoints/ranks/val/no_ft_ckpt_4/ensemble_19_22_23_24_25_26.json


# Compute the Ensemble

In [195]:
for ckpt in range(5, 10):
    pickle_paths =['/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s06_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/val/ckpt_{}/checkpoint_11.pth_{}.pkl', 
               '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s07_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/val/ckpt_{}/checkpoint_11.pth_{}.pkl', 
               '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s08_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/val/ckpt_{}/checkpoint_11.pth_{}.pkl', 
               '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s09_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/val/ckpt_{}/checkpoint_last.pth_{}.pkl', 
               '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s10_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/val/ckpt_{}/checkpoint_11.pth_{}.pkl', 
               '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s11_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/val/ckpt_{}/checkpoint_11.pth_{}.pkl', 
                 ]

    pickle_paths = [p.format(lr, str(ckpt), decoder) for p in pickle_paths]
    rank_output_path = '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/ranks/val/ensemble_s67891011_ckpt_2_ft.json'
    
    
    all_outputs, all_img_ids, all_round_ids, all_ans_ids   , all_gt_relevance = [], [], [], [], []

    for pickle_path in pickle_paths:
        with open(pickle_path, 'rb') as f:
            x = pickle.load(f)
            outputs, img_ids, _ = x
            img_ids = torch.stack(img_ids, dim=0).view(-1)
            outputs = torch.stack(outputs, dim=0).view(-1, 10, 100)
            all_outputs.append(outputs)
            all_img_ids.append(img_ids)
            all_round_ids.append(round_ids)


    for im_id, img_id in zip(im_ids, img_ids):
        if im_id != img_id:
            print("False")

    avg_output = torch.zeros_like(all_outputs[0])
    # print(avg_output.shape)

    for out in all_outputs:
        avg_output += out 
    avg_output /= float(len(all_outputs))



    sparse_metrics.observe(avg_output, ans_inds)

    rel_output = avg_output[torch.arange(avg_output.size(0)), round_ids - 1, :]
    ndcg.observe(rel_output, gt_relevances)

    all_metrics = {}
    all_metrics.update(sparse_metrics.retrieve(reset=True))
    all_metrics.update(ndcg.retrieve(reset=True))
    val_keys = ['ndcg', 'mrr', 'r@1', 'r@5', 'r@10', 'mean']
    print(f"{ckpt + 1}", end=',')
    
    for key in val_keys:
        print(all_metrics[key], end=',')
    print()

6,0.8491790890693665,0.5173311233520508,0.3804748058319092,0.6699128150939941,0.7983042597770691,6.766472816467285,
7,0.8558176159858704,0.5051973462104797,0.36419573426246643,0.6630814075469971,0.7957364320755005,6.8550872802734375,
8,0.8621454834938049,0.5007437467575073,0.35857558250427246,0.6618701815605164,0.7950581312179565,6.876114368438721,
9,0.8676555156707764,0.4963180124759674,0.3525193929672241,0.6590600609779358,0.7939438223838806,6.899176120758057,
10,0.8723283410072327,0.49351394176483154,0.348982572555542,0.6577519178390503,0.7942345142364502,6.914486408233643,


# Saving

In [147]:
ranks = scores_to_ranks(avg_output)

ranks_json = []

for i in range(len(img_ids)):
    if split == 'val':
        for j in range(10):
            ranks_json.append(
                    {
                        "image_id": img_ids[i].item(),
                        "round_id": int(j + 1),
                        "ranks"   : [rank.item() for rank in ranks[i][j]],
                        }
                    )
            
import os
import json

os.makedirs(os.path.dirname(rank_output_path), exist_ok=True)
json.dump(ranks_json, open(rank_output_path, "w"))

# For Test


In [54]:
import torch
import pickle
import os
import json

lr = '5'
split = 'test'

# for ckpt in [3, 2, 0]:
print("ckpt", ckpt)

# pickle_paths =[
#            '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s06_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_{}/disc.pkl', 
#            '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s07_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_{}/disc.pkl', 
#            '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s08_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_{}/disc.pkl', 
#            '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s09_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_{}/disc.pkl', 
#            '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s10_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_{}/disc.pkl', 
#            '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s11_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_{}/disc.pkl', 
#              ]

# pickle_paths = [p.format(lr, str(ckpt)) for p in pickle_paths]
# rank_output_path = '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/ranks/test/ensemble_s67891011_ckpt_{}_ft.json'.format(ckpt)

pickle_paths =[
           '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s06_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_3/disc.pkl', 
           '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s07_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_3/disc.pkl', 
           '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s08_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_3/disc.pkl', 
           '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s09_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_3/disc.pkl', 
           '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s10_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_3/disc.pkl', 
           '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s11_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_3/disc.pkl',
           '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s06_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_4/disc.pkl', 
           '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s07_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_4/disc.pkl', 
           '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s08_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_4/disc.pkl', 
           '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s09_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_4/disc.pkl', 
           '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s10_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_4/disc.pkl', 
           '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/s11_simple_branch/finetune/lr_{}e-05/CosineLR/ranks/test/ckpt_4/disc.pkl'   
]   
    
pickle_paths = [p.format(lr) for p in pickle_paths]
rank_output_path = '/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/ranks/test/ensemble_s67891011_ckpt_3_4_ft.json'


all_outputs, all_img_ids, all_round_ids, all_ans_ids, all_gt_relevance = [], [], [], [], []

for pickle_path in pickle_paths:
    with open(pickle_path, 'rb') as f:
        x = pickle.load(f)
        outputs, img_ids, round_ids = x
        img_ids = torch.stack(img_ids, dim=0).view(-1)
        outputs = torch.stack(outputs, dim=0).view(-1, 1, 100)
        round_ids = torch.stack(round_ids, dim=0).view(-1)
        all_outputs.append(outputs)
        all_img_ids.append(img_ids)
        all_round_ids.append(round_ids)

avg_output = torch.zeros_like(all_outputs[0])
# print(avg_output.shape)

for out in all_outputs:
    avg_output += out 
avg_output /= float(len(all_outputs))

ranks = scores_to_ranks(avg_output)

num_rounds = all_round_ids[0]
ranks_json = []

for i in range(len(img_ids)):
    if split == 'test':
        ranks_json.append(
                {
                    "image_id": img_ids[i].item(),
                    "round_id": int(num_rounds[i].item()),
                    "ranks"   : [rank.item() for rank in ranks[i][0]],
                    }
                )

os.makedirs(os.path.dirname(rank_output_path), exist_ok=True)
json.dump(ranks_json, open(rank_output_path, "w"))

print(rank_output_path)    

ckpt 0
/media/local_workspace/quang/checkpoints/visdial/CVPR/train_simple/lr001/12_epochs/ranks/test/ensemble_s67891011_ckpt_3_4_ft.json


# Old scripts 

In [27]:
all_outputs, all_img_ids, all_round_ids, all_ans_ids, all_gt_relevance = [], [], [], [], []

for pickle_path in pickle_paths:
    with open(pickle_path, 'rb') as f:
        x = pickle.load(f)
        outputs, img_ids, round_ids = x
        img_ids = torch.stack(img_ids, dim=0).view(-1)
        round_ids = torch.stack(round_ids, dim=0).view(-1)
        outputs = torch.stack(outputs, dim=0).view(-1, 10, 100)
        all_outputs.append(outputs)
        all_img_ids.append(img_ids)
        all_round_ids.append(round_ids)


ranks_json = []

for k in range(len(all_outputs[0])):
    
    img_ids = all_img_ids[0][k]
    num_rounds = all_round_ids[0][k]
    
    output = torch.zeros_like(all_outputs[0][k])
    
    for kind in range(len(all_outputs)):
        output_kind = all_outputs[kind][k]
        output += output_kind
    output = output / float(len(all_outputs))
    
    ranks = scores_to_ranks(output)
    for i in range(len(img_ids)):
        
        if split == 'test':
            ranks_json.append(
            {
                "image_id": img_ids[i].item(),
                "round_id": int(num_rounds[i].item()),
                "ranks": [rank.item() for rank in ranks[i][0]  # [batch["num_rounds"][i] - 1]
                ],
            })
        elif split == 'val':
            for j in range(num_rounds[i]):

                ranks_json.append(
                        {
                            "image_id": img_ids[i].item(),
                            "round_id": int(j + 1),
                            "ranks"   : [rank.item() for rank in ranks[i][j]],
                            }
                        )

In [28]:
import json
json.dump(ranks_json, open(rank_output_path, "w"))

In [29]:
print(rank_output_path)

/home/quang/ranks/train_simple/test/ensemble_s6_7_7cosine_disc_ckpt_3_ft.json


In [33]:
import os
os.path.basename(rank_output_path)

'ensemble_s6_7_7cosine_misc.json'

In [9]:
keys = ["NDCG (x 100)", "MRR (x 100)", "R@1", "R@5", "R@10", "Mean"]

In [60]:
out = {"test-std": {"MRR (x 100)": 64.0807984826235, "R@1": 50.2, "R@5": 80.675, "R@10": 90.35, "Mean": 4.052, "NDCG (x 100)": 59.032664470283706}}

In [61]:
if out.get('test-std') is not None:
    res = out['test-std']
else:
    res = out['val']

In [63]:
for key in keys:
    print(key, end=',')
print()
for key in keys:
    print(res[key]/100, end=',')

NDCG (x 100),MRR (x 100),R@1,R@5,R@10,Mean,
0.590326644702837,0.640807984826235,0.502,0.80675,0.9035,0.040519999999999994,

In [47]:
71.95 - 64.47

7.480000000000004