In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from metrics import *
from bert import *
from tf_idf import * 
from random_sample import *

In [2]:
# Input locations: the tab-separated ID information table and the
# tab-delimited binary relevancy matrix used as ground truth below.
INFOS_PATH = "dataset/id_information_mmsr.tsv"
RELEVANCE_PATH = "./predictions/binary_relevancy_matrix_00.csv"

infos = pd.read_csv(INFOS_PATH, sep="\t")
inter_true = np.loadtxt(RELEVANCE_PATH, delimiter="\t")

In [3]:
# Models to evaluate; each needs a prediction file at predictions/recs_{model}_10.csv.
models = ["random", "bert", "mfcc_bow", "mfcc_stats", "tf_idf", "word2vec", "ivec", "vgg19", "resnet", "incp", "genre_tags"]

# Rank cutoffs for the @k metrics. Result columns are ordered as
# recall@k, precision@k, ndcg@k for each k in turn, followed by mrr.
K_VALUES = (10, 20)

# (column prefix, metric function) pairs evaluated at every cutoff in K_VALUES;
# each function is called as metric(inter_pred, inter_true, k).
AT_K_METRICS = (
    ("recall", recall_at_k),
    ("precision", precision_at_k),
    ("ndcg", ndcg_at_k),
)

evaluation = []

for model in tqdm(models, desc="evaluating models"):
    # NOTE(review): the file suffix is `_10`, but metrics are also computed at k=20 —
    # confirm each prediction row covers at least max(K_VALUES) items.
    inter_pred = np.loadtxt(f"predictions/recs_{model}_10.csv", delimiter="\t")

    # One record per model; built as a list of dicts and turned into a frame once.
    row = {"model": model}
    for k in K_VALUES:
        for name, metric in AT_K_METRICS:
            row[f"{name}@{k}"] = metric(inter_pred, inter_true, k)
    row["mrr"] = mrr(inter_pred, inter_true)

    evaluation.append(row)

evaluation_df = pd.DataFrame(evaluation)

  0%|          | 0/11 [00:00<?, ?it/s]

In [4]:
# Display the evaluation results: one row per model, one column per metric.
evaluation_df

Unnamed: 0,model,recall@10,precision@10,ndcg@10,recall@20,precision@20,ndcg@20,mrr
0,random,0.002001,0.063326,0.063499,0.003621,0.058974,0.060342,0.001685
1,bert,0.00782,0.123932,0.126649,0.009493,0.089831,0.101855,0.002264
2,mfcc_bow,0.007672,0.142968,0.148872,0.009337,0.098737,0.115926,0.002488
3,mfcc_stats,0.005499,0.131605,0.135092,0.007095,0.093094,0.106694,0.002353
4,tf_idf,0.003936,0.093901,0.094866,0.005587,0.074223,0.080608,0.001972
5,word2vec,0.0059,0.103788,0.1064,0.007537,0.078963,0.088047,0.002088
6,ivec,0.00881,0.133683,0.142101,0.010515,0.094435,0.111974,0.002447
7,vgg19,0.005642,0.109421,0.120357,0.007394,0.082935,0.098101,0.002282
8,resnet,0.006341,0.121737,0.13338,0.007988,0.08816,0.105929,0.002399
9,incp,0.005388,0.113772,0.124578,0.007107,0.084606,0.100386,0.002318
