In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
from experimental_metrics import *

In [2]:
# Load the dataset tables and the ground-truth relevance matrix used by the
# evaluation loop. Every table is tab-separated.


def _read_tsv(path):
    """Read a tab-separated file into a DataFrame."""
    return pd.read_csv(path, sep="\t")


# Per-item metadata. Only its "popularity" column is pulled out here, as a
# NumPy array for the popularity/novelty metrics.
# NOTE(review): presumably row order matches the item ids used in the
# prediction files — confirm.
popularity_df = _read_tsv("dataset/id_metadata_mmsr.tsv")
popularity = popularity_df["popularity"].to_numpy()

# NOTE(review): `infos` is not referenced in the visible cells.
infos = _read_tsv("dataset/id_information_mmsr.tsv")

# Binary relevancy matrix. Its row count is used as the catalogue size for coverage.
inter_true = np.loadtxt("./predictions/binary_relevancy_matrix_00.csv", delimiter="\t")

# Tag and genre tables consumed by the diversity metric.
tags = _read_tsv("./dataset/id_tags_dict.tsv")
genres = _read_tsv("./dataset/id_genres_mmsr.tsv")

In [6]:
# Cutoff k shared by every @k metric. It also selects which prediction files
# are read (predictions/recs_<model>_<K>.csv), so those files must exist for
# the chosen K.
K = 10

models = ["random", "bert", "bert_adv", "mfcc_bow", "mfcc_stats", "tf_idf", "item_item", "word2vec", "ivec", "vgg19", "resnet", "incp", "genre_tags"]
evaluation = []

# Build one record of beyond-accuracy metrics per model. The metric helpers are
# not defined in this notebook; presumably they come from the star import of
# experimental_metrics.
for model in tqdm(models):
    # ndmin=2 keeps a single-row prediction file 2-D instead of letting
    # np.loadtxt collapse it to a 1-D vector.
    inter_pred = np.loadtxt(f"predictions/recs_{model}_{K}.csv", delimiter="\t", ndmin=2)

    # Dict values are evaluated top to bottom, so the metrics run in the same
    # order as before: pop, cov, ild, nov, div.
    evaluation.append({
        "model": model,
        f"pop@{K}": avg_popularity_at_k(inter_pred, popularity, K),
        # Catalogue size is taken from the row count of inter_true.
        # NOTE(review): assumes one row per item — confirm.
        f"cov@{K}": avg_coverage_at_k(inter_pred, inter_true.shape[0], K),
        f"ild@{K}": intra_list_diversity_at_k(inter_pred, inter_true, K),
        f"nov@{K}": novelty_at_k(inter_pred, popularity, K),
        f"div@{K}": diversity_at_k(inter_pred, tags, genres, K),
    })

# Build the table once from the list of records (one row per model).
evaluation_df = pd.DataFrame(evaluation)

100%|██████████████████████████████████████████████████████████████████████████████████| 13/13 [01:01<00:00,  4.73s/it]


In [4]:
# Display the per-model metric table built by the evaluation loop.
# NOTE(review): in the saved notebook this cell ran (In [4]) before the
# evaluation cell's last run (In [6]). Its stored output has an "Unnamed: 0"
# column and only 10 of the 13 models in `models`, so it is stale.
# Restart & Run All to refresh it.
evaluation_df

Unnamed: 0,model,pop@10,cov@10,ild@10,nov@10,div@10
0,random,35.145513,1.0,0.936407,-4.955561,29.000389
1,bert,37.322552,0.833139,0.875114,-5.06828,28.711344
2,bert_adv,37.308741,0.833722,0.875248,-5.067765,28.707071
3,mfcc_bow,36.578768,0.912005,0.859682,-5.03858,28.486597
4,mfcc_stats,36.220513,0.865579,0.865303,-5.020059,28.40812
5,tf_idf,36.128691,0.976496,0.904105,-5.003819,28.924437
6,item_item,34.416317,0.849456,0.920176,-4.909247,28.782634
7,word2vec,39.092036,0.761267,0.869084,-5.13493,28.648019
8,ivec,35.006119,0.99864,0.8976,-4.955263,28.68784
9,vgg19,36.168434,0.876457,0.897885,-5.018779,28.790016
