In [1]:
# Change the kernel's working directory to the project folder on the mounted Google Drive.
# NOTE(review): presumably this is what lets the later `from arena_util import load_json` resolve — confirm.
cd "/content/drive/My Drive/Colab Notebooks/테스트용"

/content/drive/My Drive/Colab Notebooks/테스트용


In [2]:
# Install the `fire` package into the Colab runtime (it is imported in the next cell).
# NOTE(review): none of the visible cells actually call anything from `fire`; the install may be unnecessary.
!pip install fire

Collecting fire
  Downloading https://files.pythonhosted.org/packages/34/a7/0e22e70778aca01a52b9c899d9c145c6396d7b613719cd63db97ffa13f2f/fire-0.3.1.tar.gz (81kB)
     |████████████████████████████████| 81kB 3.0MB/s 
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... done
  Created wheel for fire: filename=fire-0.3.1-py2.py3-none-any.whl size=111005 sha256=2b0c400b315943709876a66ed3e9775765c24e3e0bba2df3fbe9df52a4a25b9b
  Stored in directory: /root/.cache/pip/whe

In [3]:
import os
import json

import fire
import numpy as np
from tqdm import tqdm


In [4]:
from arena_util import load_json

In [5]:
class ArenaEvaluator:
    """Scores playlist recommendations (songs and tags) with nDCG.

    Three JSON files are involved, each a list of playlist dicts keyed by
    ``"id"``:

    * the *ground truth* (answers) with the held-out ``"songs"`` / ``"tags"``,
    * the *recommendations* with exactly 100 unique songs and 10 unique tags
      per playlist,
    * the *questions* with the ``"songs"``, ``"tags"`` and ``"plylst_title"``
      that were visible when the recommendation was made.

    The final score is ``0.85 * song nDCG + 0.15 * tag nDCG``. In addition to
    the overall score, a breakdown is reported for the four combinations of
    "question had tags" x "question had songs".
    """

    # Default file locations used by this notebook (Google Drive mount).
    DEFAULT_SONG_META_PATH = '/content/drive/My Drive/Colab Notebooks/테스트용/test_data/song_meta.json'
    DEFAULT_ANSWERS_PATH = '/content/drive/My Drive/Colab Notebooks/테스트용/split_data/answers/val.json'
    DEFAULT_QUESTIONS_PATH = '/content/drive/My Drive/Colab Notebooks/테스트용/split_data/questions/val.json'
    DEFAULT_RESULTS_PATH = '/content/drive/My Drive/Colab Notebooks/테스트용/arena_data/results_test10.json'

    # Required number of recommended songs / tags per playlist.
    MAX_SONGS = 100
    MAX_TAGS = 10

    # Weights of the song and tag nDCG in the combined score.
    SONG_WEIGHT = 0.85
    TAG_WEIGHT = 0.15

    # Breakdown buckets in report order:
    # (code, printed label, question has tags, question has songs).
    _BUCKETS = (
        ("TNSY", "태그X 곡O", False, True),
        ("TYSN", "태그O 곡X", True, False),
        ("TYSY", "태그O 곡O", True, True),
        ("TNSN", "태그X 곡X", False, False),
    )

    def _idcg(self, l):
        """Return the ideal DCG for ``l`` relevant items (0 when ``l`` is 0)."""
        return sum((1.0 / np.log(i + 2) for i in range(l)))

    def __init__(self, song_meta_path=None):
        """Pre-compute ideal DCG values and load the song metadata.

        Args:
            song_meta_path: Path of the song metadata JSON file. Defaults to
                ``DEFAULT_SONG_META_PATH`` so ``ArenaEvaluator()`` behaves as
                before.

        Raises:
            OSError: If the metadata file cannot be opened.
        """
        # Cache of ideal DCG values for ground-truth sizes 0..MAX_SONGS.
        self._idcgs = [self._idcg(i) for i in range(self.MAX_SONGS + 1)]
        if song_meta_path is None:
            song_meta_path = self.DEFAULT_SONG_META_PATH
        # NOTE: song_meta is not read by any method of this class; it is kept
        # as a public attribute for code that inspects the evaluator.
        with open(song_meta_path, encoding="utf-8") as f:
            self.song_meta = json.load(f)

    def _ndcg(self, gt, rec):
        """Return the nDCG of the ranked list ``rec`` against ``gt``.

        Args:
            gt: Ground-truth items (hashable, e.g. song ids or tag strings).
            rec: Recommended items, best first.

        Returns:
            DCG of ``rec`` divided by the ideal DCG for ``len(gt)`` items.
            ``0.0`` when ``gt`` is empty (nothing could have been hit), which
            previously raised ``ZeroDivisionError``.
        """
        n_gt = len(gt)
        if n_gt == 0:
            return 0.0

        relevant = set(gt)  # constant-time membership instead of a list scan
        dcg = 0.0
        for i, r in enumerate(rec):
            if r in relevant:
                dcg += 1.0 / np.log(i + 2)

        # Ground truths longer than the cache are computed on demand instead
        # of failing with IndexError.
        if n_gt < len(self._idcgs):
            idcg = self._idcgs[n_gt]
        else:
            idcg = self._idcg(n_gt)
        return dcg / idcg

    def _score(self, music_ndcg, tag_ndcg):
        """Combine song and tag nDCG into the weighted final score."""
        return music_ndcg * self.SONG_WEIGHT + tag_ndcg * self.TAG_WEIGHT

    def _eval(self, gt_fname, rec_fname):
        """Validate the recommendation file and compute the overall scores.

        Args:
            gt_fname: Path of the ground-truth JSON file.
            rec_fname: Path of the recommendation JSON file.

        Returns:
            ``(music_ndcg, tag_ndcg, score)`` averaged over all playlists.

        Raises:
            Exception: If the playlist ids differ from the ground truth, or a
                playlist does not hold exactly 100 unique songs and 10 unique
                tags.
        """
        gt_playlists = load_json(gt_fname)
        gt_dict = {g["id"]: g for g in gt_playlists}
        rec_playlists = load_json(rec_fname)

        gt_ids = {g["id"] for g in gt_playlists}
        rec_ids = {r["id"] for r in rec_playlists}

        if gt_ids != rec_ids:
            raise Exception("결과의 플레이리스트 수가 올바르지 않습니다.")

        # Every playlist must carry exactly the required number of items.
        # (An empty recommendation list also fails here: set() != {100}.)
        if {len(p["songs"]) for p in rec_playlists} != {self.MAX_SONGS}:
            raise Exception("추천 곡 결과의 개수가 맞지 않습니다.")

        if {len(p["tags"]) for p in rec_playlists} != {self.MAX_TAGS}:
            raise Exception("추천 태그 결과의 개수가 맞지 않습니다.")

        # ... and those items must be unique within the playlist.
        if {len(set(p["songs"])) for p in rec_playlists} != {self.MAX_SONGS}:
            raise Exception("한 플레이리스트에 중복된 곡 추천은 허용되지 않습니다.")

        if {len(set(p["tags"])) for p in rec_playlists} != {self.MAX_TAGS}:
            raise Exception("한 플레이리스트에 중복된 태그 추천은 허용되지 않습니다.")

        music_ndcg = 0.0
        tag_ndcg = 0.0

        for rec in rec_playlists:
            gt = gt_dict[rec["id"]]
            music_ndcg += self._ndcg(gt["songs"], rec["songs"][:self.MAX_SONGS])
            tag_ndcg += self._ndcg(gt["tags"], rec["tags"][:self.MAX_TAGS])

        music_ndcg = music_ndcg / len(rec_playlists)
        tag_ndcg = tag_ndcg / len(rec_playlists)
        score = self._score(music_ndcg, tag_ndcg)

        return music_ndcg, tag_ndcg, score

    def _detail_from_playlists(self, gt_dict, qt_dict, rec_playlists):
        """Compute the per-bucket breakdown from already-loaded playlists.

        Args:
            gt_dict: Ground-truth playlists keyed by id.
            qt_dict: Question playlists keyed by id.
            rec_playlists: Iterable of recommendation playlists.

        Returns:
            A 20-tuple: ``(music_ndcg, tag_ndcg, score)`` for TNSY, TYSN, TYSY
            and TNSN, followed by ``(title_yes, title_no)`` counts for the same
            four buckets. A bucket that received no playlist reports ``0.0``
            for its three scores instead of raising ``ZeroDivisionError``.
        """
        stats = {
            (has_tags, has_songs): {
                "music": 0.0, "tag": 0.0, "count": 0, "title_y": 0, "title_n": 0,
            }
            for _, _, has_tags, has_songs in self._BUCKETS
        }

        for rec in rec_playlists:
            qts = qt_dict[rec["id"]]
            gt = gt_dict[rec["id"]]
            # The bucket is determined by what the question exposed.
            bucket = stats[(len(qts["tags"]) > 0, len(qts["songs"]) > 0)]
            if qts["plylst_title"]:
                bucket["title_y"] += 1
            else:
                bucket["title_n"] += 1
            bucket["music"] += self._ndcg(gt["songs"], rec["songs"][:self.MAX_SONGS])
            bucket["tag"] += self._ndcg(gt["tags"], rec["tags"][:self.MAX_TAGS])
            bucket["count"] += 1

        scores = []
        title_counts = []
        for _, _, has_tags, has_songs in self._BUCKETS:
            bucket = stats[(has_tags, has_songs)]
            if bucket["count"]:
                music = bucket["music"] / bucket["count"]
                tag = bucket["tag"] / bucket["count"]
            else:
                # Empty bucket: there is nothing to average.
                music = tag = 0.0
            scores.extend((music, tag, self._score(music, tag)))
            title_counts.extend((bucket["title_y"], bucket["title_n"]))

        return tuple(scores + title_counts)

    def _eval_detail(self, gt_fname, rec_fname, qt_fname):
        """Load the three files and compute the per-bucket breakdown.

        Args:
            gt_fname: Path of the ground-truth JSON file.
            rec_fname: Path of the recommendation JSON file.
            qt_fname: Path of the question JSON file.

        Returns:
            The 20-tuple described in ``_detail_from_playlists``.
        """
        gt_dict = {g["id"]: g for g in load_json(gt_fname)}
        rec_playlists = load_json(rec_fname)
        qt_dict = {q["id"]: q for q in load_json(qt_fname)}

        # tqdm only wraps the iteration to show a progress bar.
        return self._detail_from_playlists(gt_dict, qt_dict, tqdm(rec_playlists))

    def evaluate(self, gt_fname, rec_fname, qt_fname):
        """Print the overall scores followed by the per-bucket breakdown.

        Any exception (validation failures included) is caught and only its
        message is printed, so nothing propagates to the caller.

        Args:
            gt_fname: Path of the ground-truth JSON file.
            rec_fname: Path of the recommendation JSON file.
            qt_fname: Path of the question JSON file.
        """
        try:
            music_ndcg, tag_ndcg, score = self._eval(gt_fname, rec_fname)
            detail = self._eval_detail(gt_fname, rec_fname, qt_fname)

            # f-strings (literal string interpolation) keep the formatting
            # readable; any expression may appear inside the braces.
            print("")
            print(f"Music nDCG: {music_ndcg:.6}")
            print(f"Tag nDCG: {tag_ndcg:.6}")
            print(f"Score: {score:.6}\n")

            # Layout of `detail`: 3 scores per bucket, then 2 title counts per
            # bucket, both in _BUCKETS order.
            n_buckets = len(self._BUCKETS)
            for idx, (code, label, _, _) in enumerate(self._BUCKETS):
                b_music, b_tag, b_score = detail[3 * idx:3 * idx + 3]
                counts_at = 3 * n_buckets + 2 * idx
                title_y, title_n = detail[counts_at:counts_at + 2]

                print(label)
                print(f"{code} title yes: ", title_y)
                print(f"{code} title no: ", title_n)
                print(f"{code} Music nDCG: {b_music:.6}")
                print(f"{code} Tag nDCG: {b_tag:.6}")
                print(f"{code} Score: {b_score:.6}\n")
        except Exception as e:
            print(e)

    def run(self, answers_path=None, questions_path=None, results_path=None):
        """Evaluate a result file against the validation split.

        Args:
            answers_path: Ground-truth JSON; defaults to ``DEFAULT_ANSWERS_PATH``.
            questions_path: Question JSON; defaults to ``DEFAULT_QUESTIONS_PATH``.
            results_path: Recommendation JSON; defaults to ``DEFAULT_RESULTS_PATH``.
        """
        if answers_path is None:
            answers_path = self.DEFAULT_ANSWERS_PATH
        if questions_path is None:
            questions_path = self.DEFAULT_QUESTIONS_PATH
        if results_path is None:
            results_path = self.DEFAULT_RESULTS_PATH
        self.evaluate(answers_path, results_path, questions_path)

In [None]:
# Build the evaluator (its constructor loads song_meta.json from Drive), then
# print the overall nDCG scores and the per-bucket breakdown for the
# validation split.
U_space = ArenaEvaluator()
U_space.run()

100%|██████████| 23015/23015 [00:02<00:00, 9299.42it/s]



Music nDCG: 0.233984
Tag nDCG: 0.457172
Score: 0.267462

태그X 곡O
TNSY title yes:  9550
TNSY title no:  0
TNSY Music nDCG: 0.264134
TNSY Tag nDCG: 0.526456
TNSY Score: 0.303483

태그O 곡X
TYSN title yes:  2618
TYSN title no:  0
TYSN Music nDCG: 0.0817263
TYSN Tag nDCG: 0.321091
TYSN Score: 0.117631

태그O 곡O
TYSY title yes:  8859
TYSY title no:  0
TYSY Music nDCG: 0.284458
TYSY Tag nDCG: 0.440795
TYSY Score: 0.307909

태그X 곡X
TNSN title yes:  1988
TNSN title no:  0
TNSN Music nDCG: 0.0647339
TNSN Tag nDCG: 0.376532
TNSN Score: 0.111504

