In [None]:
# 1. Mount Google Drive so the dataset and checkpoint folders under
#    /content/drive/MyDrive are reachable from this Colab runtime.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
import json
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
from tqdm import tqdm
import pickle
import os
from IPython.display import clear_output, display
from collections import Counter
import re


class MelonMusicRecommender:
    """Tag-based music recommender built on skip-gram-with-negative-sampling (SGNS).

    Pipeline:
      1. ``load_data`` reads ``song_meta.json``, ``train.json``, ``val.json``,
         ``test.json`` and (best-effort) ``genre_gn_all.json`` from ``data_path``.
      2. ``preprocess_data`` turns the tags of the training playlists into a
         per-song "sentence" of normalized tags.
      3. ``train_model`` trains an SGNS model on tag co-occurrence within those
         sentences.
      4. ``create_song_embeddings`` averages tag vectors into song vectors, and
         ``recommend_songs`` ranks songs by cosine similarity to a playlist.
      5. ``evaluate`` masks part of each playlist and scores the recommendations.
    """

    def __init__(self, data_path='/content/drive/MyDrive/Graduation_Project/Dataset/'):
        self.data_path = data_path
        self.song_meta = None
        self.train_data = None
        self.val_data = None
        self.test_data = None
        self.song_embeddings = None
        self.song_to_idx = {}
        self.idx_to_song = {}
        self.vocabulary = set()
        self.word_to_idx = {}
        self.idx_to_word = {}
        self.genre_map = {}
        self.word_counts = None
        self.neg_sampling_p = None
        self.neg_sampling_cdf = None
        # Filled in later by preprocess_data / build_sgns_model / build_search_matrix.
        self.song_sentences = {}
        self.model = None
        self.E_norm = None

        # One seed shared by the instance RNG, NumPy's global RNG and TensorFlow,
        # so that sampling and training are reproducible.
        self.seed = 42
        self.rng = np.random.default_rng(self.seed)
        np.random.seed(self.seed)
        tf.random.set_seed(self.seed)
        # Catalog of song ids that survive preprocessing.
        self.kept_song_ids = set()

    def clean_text(self, text):
        """Remove parenthesised fragments, e.g. ``"여름(summer)"`` -> ``"여름"``.

        The fragment and its surrounding whitespace are replaced by one space
        (not the empty string) so that words on either side are not glued
        together: ``"a (b) c"`` -> ``"a c"``.
        """
        text = re.sub(r'\s*\([^)]*\)\s*', ' ', str(text))
        return text.strip()

    def load_data(self):
        """Load song metadata, the train/val/test playlists and the genre map.

        ``song_meta`` is stored as a dict keyed by ``str(song['id'])``. A missing
        or malformed genre file is reported and otherwise ignored, leaving
        ``genre_map`` as it was.
        """
        print("Loading dataset...")

        with open(os.path.join(self.data_path, 'song_meta.json'), 'r', encoding='utf-8') as f:
            song_meta_list = json.load(f)
        self.song_meta = {str(song['id']): song for song in song_meta_list}

        with open(os.path.join(self.data_path, 'train.json'), 'r', encoding='utf-8') as f:
            self.train_data = json.load(f)
        with open(os.path.join(self.data_path, 'val.json'), 'r', encoding='utf-8') as f:
            self.val_data = json.load(f)
        with open(os.path.join(self.data_path, 'test.json'), 'r', encoding='utf-8') as f:
            self.test_data = json.load(f)

        genre_file_path = os.path.join(self.data_path, 'genre_gn_all.json')
        try:
            with open(genre_file_path, 'r', encoding='utf-8') as f:
                self.genre_map = json.load(f)
            print(f"Loaded {len(self.genre_map)} genre mappings.")
        except FileNotFoundError:
            print(f"Error: The file {genre_file_path} was not found.")
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON from '{genre_file_path}': {e}")

        print(f"Loaded {len(self.song_meta)} songs")
        print(f"Train playlists: {len(self.train_data)}")
        print(f"Val playlists: {len(self.val_data)}")
        print(f"Test playlists: {len(self.test_data)}")

    def preprocess_data(self, topk_per_song: int = 30, min_tags_per_song: int = 4):
        """Build per-song tag "sentences" from the training playlists (tags only).

        Rules:
          0) Drop tags that contain '_'.
          1) Count each tag at most once per playlist.
          2) Strip a trailing '한' or '인' (몽환적인 -> 몽환적, 포근한 -> 포근).
          3) Containment: when one tag is a prefix/suffix of another, keep only
             the shorter one.
          4) Year normalization: 1990, 1990년도, 90s, '90 -> 1990;
             '80-'90 -> 1980, 1990. Two-digit years are read as 19xx.
          5) Keep the ``topk_per_song`` most frequent tags per song, and drop songs
             with fewer than ``min_tags_per_song`` tags.
        Additionally:
          - Merge aliases/typos (카페/까페, 캐럴/캐롤, Pop/팝, ...).
          - Drop broadcaster/brand and other weak-signal tokens.
          - Build the vocabulary and the negative-sampling distribution.
          - Print statistics and a sample of up to 100 songs.

        Sets ``song_sentences`` and ``vocabulary``. If the vocabulary is not
        empty, it also sets ``word_to_idx``, ``idx_to_word``, ``song_to_idx``,
        ``idx_to_song``, ``kept_song_ids``, ``neg_sampling_p`` and
        ``neg_sampling_cdf``. Returns None.
        """
        from collections import defaultdict

        print("Preprocessing data (tags only)...")

        # ---------- normalization dictionaries ----------
        ALIAS = {
            "pop": "팝",
            "jpop": "제이팝",
            "kpop": "케이팝",
            "jazz": "재즈",
            "hiphop": "힙합",
            "rnb": "알앤비",
            "r&b": "알앤비",
            "soul": "소울",
            "cafe": "카페",
            "bgm": "배경음악",
            "rock": "락",
            "ost": "OST",
            "christmas": "크리스마스",
            "carol": "캐롤",
            "xmas": "크리스마스",
            "fall": "가을",
            "moon": "달"
        }
        # NOTE: keys containing whitespace or ',' ("r n b", "hip hop", "j-pop,")
        # can never match, because tokenize_tag splits on those characters first.
        TYPO = {
            "까페": "카페",
            "캐럴": "캐롤",
            "따듯": "따뜻",
            "알엔비": "알앤비",
            "pop": "팝",
            "jazz": "재즈",
            "hiphop": "힙합",
            "jpop": "제이팝",
            "jpop.": "제이팝",
            "rnb": "알앤비",
            "rnb.": "알앤비",
            "r&b": "알앤비",
            "r n b": "알앤비",
            "rn b": "알앤비",
            "rn'b": "알앤비",
            "hip hop": "힙합",
            "j-pop": "제이팝",
            "k-pop": "케이팝",
            "cafe": "카페",
            "bgm": "배경음악",
            "j-pop.": "제이팝",
            "j-pop,": "제이팝",
        }
        STOP_SINGLE = {"r", "와", "라"}  # weak single-character tokens
        STOP_WEAK = {
            "추천", "인기", "명곡", "애창곡", "띵곡", "차트", "장르불문", "좋은", "기분", "노래", "음악",
            "카카오톡",
        }
        BRANDS = {"mbc", "jtbc", "fm4u", "오픈채팅", "차트100"}  # broadcaster/brand tokens

        # ---------- regular expressions ----------
        # '90, 1990, 1990년도, 1990s, '90s
        YEAR_SINGLE_PAT = re.compile(r"^'?(\d{2}|\d{4})(?:년|년도|s)?$")
        # '80-'90, '80~'90
        YEAR_RANGE_PAT = re.compile(r"^'(\d{2})\s*[-~]\s*'(\d{2})$")
        # 1990년대, 90s, '90s
        DECADE_PAT = re.compile(r"^'?(\d{2}|\d{4})\s*(?:년대|s)$")
        # Separators used to split one raw tag into several tokens.
        SPLIT_PAT = re.compile(r"[\/,\s#]+")

        def normalize_year_token(tok: str):
            """Return 0, 1 or 2 normalized year strings for *tok* (two-digit -> 19xx)."""
            t = tok.strip().lower()

            mrange = YEAR_RANGE_PAT.match(t)
            if mrange:
                a, b = mrange.groups()
                return [str(1900 + int(a)), str(1900 + int(b))]

            mdec = DECADE_PAT.match(t)
            if mdec:
                g = mdec.group(1)
                return [str(1900 + int(g) if len(g) == 2 else int(g))]

            m = YEAR_SINGLE_PAT.match(t)
            if m:
                g = m.group(1)
                return [str(1900 + int(g))] if len(g) == 2 else [g]
            return []

        def normalize_suffix_ko(tok: str):
            """Strip one trailing '한' or '인' (so '몽환적인' -> '몽환적', '포근한' -> '포근')."""
            # The former special case cut '적인' with tok[:-2], which also removed
            # '적' ('몽환적인' -> '몽환'), contradicting the documented rule.
            if len(tok) > 1 and tok.endswith(("한", "인")):
                return tok[:-1]
            return tok

        def finalize_token(part: str):
            """Normalize one token; return [] (dropped), [tok], or 1-2 year strings."""
            tok = part.strip()
            if not tok or "_" in tok:  # rule 0
                return []

            tok = tok.lower()  # only affects Latin characters
            tok = TYPO.get(tok, tok)
            tok = ALIAS.get(tok, tok)

            if tok in BRANDS:
                return []

            # Single non-digit characters are always dropped. Single digits fall
            # through; they do not match the year patterns and are kept as-is.
            if len(tok) == 1 and not tok.isdigit():
                return []

            years = normalize_year_token(tok)
            if years:
                return years

            tok = normalize_suffix_ko(tok)
            if not tok or "_" in tok:
                return []

            # Suffix stripping can produce a stopword (e.g. '와인' -> '와'), so the
            # stop lists, including STOP_SINGLE, are checked after normalization.
            if tok in BRANDS or tok in STOP_WEAK or tok in STOP_SINGLE:
                return []
            return [tok]

        def tokenize_tag(raw):
            """Raw tag -> split on separators -> finalize each piece."""
            pieces = SPLIT_PAT.split(self.clean_text(raw))
            return [tok for piece in pieces if piece for tok in finalize_token(piece)]

        # ---------- step 1: per-song tag counts (each tag once per playlist) ----------
        song_tag_counter = defaultdict(Counter)  # song_id(str) -> Counter(tag -> count)
        for pl in self.train_data:
            tokens = []
            for raw in pl.get("tags", []):
                tokens.extend(tokenize_tag(raw))
            tokens = list(dict.fromkeys(tokens))  # rule 1
            for sid in pl.get("songs", []):
                song_tag_counter[str(sid)].update(tokens)

        def contained(shorter: str, longer: str) -> bool:
            """True if *shorter* is a strict prefix or suffix of *longer*."""
            if len(shorter) >= len(longer):
                return False
            return longer.startswith(shorter) or longer.endswith(shorter)

        def compact_by_containment(counter: Counter) -> Counter:
            """Apply rule 3, scanning by frequency desc, then length asc, then text.

            A new tag that contains an already-kept tag is dropped. A new tag
            contained in a kept tag replaces it, and their counts are summed.
            """
            items = sorted(counter.items(), key=lambda x: (-x[1], len(x[0]), x[0]))
            kept = []  # [(tag, count)]
            for tag, cnt in items:
                merged = False
                for i, (ex_tag, ex_cnt) in enumerate(kept):
                    if contained(ex_tag, tag):
                        merged = True
                        break
                    if contained(tag, ex_tag):
                        kept[i] = (tag, ex_cnt + cnt)
                        merged = True
                        break
                if not merged:
                    kept.append((tag, cnt))
            return Counter(dict(kept))

        filtered_song_tags = {}
        for sid, ctr in song_tag_counter.items():
            if not ctr:
                continue
            ctr2 = compact_by_containment(ctr)  # rule 3
            tags = [t for t, _ in ctr2.most_common(topk_per_song)]  # rule 5 (top-k)
            if len(tags) < min_tags_per_song:  # rule 5 (minimum)
                continue
            filtered_song_tags[sid] = tags

        # ---------- statistics / sample output ----------
        tag_lens = [len(v) for v in filtered_song_tags.values()]
        mean_len = float(np.mean(tag_lens)) if tag_lens else 0.0
        median_len = float(np.median(tag_lens)) if tag_lens else 0.0

        print(f"[RESULT] kept songs: {len(filtered_song_tags)}")
        print(f"[RESULT] avg #tags per song: {mean_len:.2f}")
        print(f"[RESULT] median #tags per song: {median_len:.0f}")
        print("[SAMPLE 100 songs' tags]")
        for i, (sid, tags) in enumerate(filtered_song_tags.items()):
            print(f"{sid}: {tags}")
            if i >= 99:
                break

        # ---------- training-pipeline state ----------
        self.song_sentences = filtered_song_tags
        self.vocabulary = {t for tags in filtered_song_tags.values() for t in tags}

        if not self.vocabulary:
            print("Warning: Vocabulary is empty after filtering.")
            return

        global_counter = Counter(t for tags in filtered_song_tags.values() for t in tags)
        vocab_sorted = sorted(self.vocabulary, key=lambda x: (-global_counter[x], x))
        self.word_to_idx = {w: i for i, w in enumerate(vocab_sorted)}
        self.idx_to_word = {i: w for w, i in self.word_to_idx.items()}

        unique_songs = list(filtered_song_tags.keys())
        self.song_to_idx = {s: i for i, s in enumerate(unique_songs)}
        self.idx_to_song = {i: s for s, i in self.song_to_idx.items()}
        self.kept_song_ids = set(unique_songs)

        # Negative-sampling distribution: unigram counts raised to the power 0.75.
        unigrams = np.array([global_counter[w] for w in vocab_sorted], dtype=np.float64)
        p = np.power(unigrams, 0.75)
        p /= p.sum()
        self.neg_sampling_p = p
        cdf = np.cumsum(p)
        # Pin the last value to exactly 1.0. Round-off can leave it slightly
        # below 1, and then searchsorted would return len(vocab), an invalid index.
        cdf[-1] = 1.0
        self.neg_sampling_cdf = cdf

        print(f"Vocabulary size: {len(self.vocabulary)}")
        print(f"Total songs after filtering: {len(self.song_sentences)}")
        any_sid = next(iter(self.song_sentences)) if self.song_sentences else None
        if any_sid:
            print(f"Sample sentence for {any_sid}: {self.song_sentences[any_sid][:10]}")
        print("Negative sampling distribution and CDF created.")

    def build_sgns_model(self, embedding_size=128, learning_rate=0.1):
        """Create, compile and store (as ``self.model``) the SGNS Keras model.

        The model takes a ``(target, context)`` pair of int index batches and
        returns one logit per pair: the dot product of the target and context
        embeddings. It is trained with binary cross-entropy on logits using SGD
        with Nesterov momentum.
        """
        print("Building SGNS model...")
        vocab_size = len(self.vocabulary)

        class SkipGram(keras.Model):
            def __init__(self, vocab_size, embedding_size):
                super().__init__()
                self.target_embedding = layers.Embedding(vocab_size, embedding_size, name="target_embedding")
                self.context_embedding = layers.Embedding(vocab_size, embedding_size, name="context_embedding")

            def call(self, inputs):
                target, context = inputs
                t = self.target_embedding(tf.cast(target, tf.int32))    # (B, D)
                c = self.context_embedding(tf.cast(context, tf.int32))  # (B, D)
                return tf.cast(tf.reduce_sum(t * c, axis=1), tf.float32)  # (B,)

        self.model = SkipGram(vocab_size, embedding_size)
        self.model.compile(
            optimizer=keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True),
            loss=keras.losses.BinaryCrossentropy(from_logits=True)
        )
        return self.model

    def generate_training_data(self, window_size=5, num_negative_samples=10, max_samples=1000000):
        """Create shuffled ``(targets, contexts, labels)`` arrays for SGNS.

        Positive pairs are tag pairs within ``window_size`` of each other in a
        song sentence, capped at ``max_samples``. Each positive yields
        ``num_negative_samples`` negatives on average: the target is drawn from
        a random positive pair, and the context is drawn from the
        negative-sampling CDF, rejecting draws equal to the target.

        Returns three empty arrays when no data can be produced.

        NOTE(review): songs are visited in ``song_sentences`` order and the
        positive cap is hit early, so later songs may contribute no pairs.
        """
        print("Generating training data...")

        positive_pairs = []  # [(target_idx, context_idx)]
        for _song_id, sentence in tqdm(list(self.song_sentences.items()), desc="Creating positive samples"):
            word_indices = [self.word_to_idx[w] for w in sentence if w in self.word_to_idx]
            for i, target in enumerate(word_indices):
                lo = max(0, i - window_size)
                hi = min(len(word_indices), i + window_size + 1)
                for j in range(lo, hi):
                    if i != j:
                        positive_pairs.append((target, word_indices[j]))
                        if len(positive_pairs) >= max_samples:
                            break
                if len(positive_pairs) >= max_samples:
                    break
            if len(positive_pairs) >= max_samples:
                break

        if not positive_pairs:
            print("Warning: No positive pairs created. Check data format.")
            return np.array([]), np.array([]), np.array([])

        print(f"Created {len(positive_pairs)} positive pairs")

        vocab_size = len(self.vocabulary)
        cdf = self.neg_sampling_cdf
        if cdf is None or len(cdf) != vocab_size:
            print("Error: Negative sampling CDF is not correctly initialized.")
            return np.array([]), np.array([]), np.array([])
        if vocab_size < 2:
            # With a single word every negative equals its target, so the
            # rejection loop below would never terminate.
            print("Error: Vocabulary too small for negative sampling.")
            return np.array([]), np.array([]), np.array([])

        pos = np.asarray(positive_pairs, dtype=np.int32)  # (P, 2)
        pos_targets = pos[:, 0]
        last_idx = vocab_size - 1

        def draw_negatives(size):
            # Inverse-CDF sampling; the clip guards against an index past the end.
            u = self.rng.random(size)
            return np.minimum(np.searchsorted(cdf, u, side="right"), last_idx).astype(np.int32)

        num_neg_total = min(len(pos) * num_negative_samples, max_samples * num_negative_samples)
        print(f"Creating {num_neg_total} negative samples using vectorized approach...")

        neg_targets = np.empty(num_neg_total, dtype=np.int32)
        neg_contexts = np.empty(num_neg_total, dtype=np.int32)
        batch_size = 10000
        for start in tqdm(range(0, num_neg_total, batch_size), desc="Creating negative samples"):
            stop = min(start + batch_size, num_neg_total)
            n_batch = stop - start
            selected = self.rng.integers(0, len(pos), size=n_batch)
            batch_targets = pos_targets[selected]

            negs = draw_negatives(n_batch)
            mask = negs == batch_targets
            while mask.any():  # resample only the collisions
                negs[mask] = draw_negatives(int(mask.sum()))
                mask = negs == batch_targets

            neg_targets[start:stop] = batch_targets
            neg_contexts[start:stop] = negs

        print(f"Created {num_neg_total} negative pairs")

        targets = np.concatenate([pos[:, 0], neg_targets])
        contexts = np.concatenate([pos[:, 1], neg_contexts])
        labels = np.concatenate([
            np.ones(len(pos), dtype=np.float32),
            np.zeros(num_neg_total, dtype=np.float32),
        ])

        perm = self.rng.permutation(len(targets))
        targets, contexts, labels = targets[perm], contexts[perm], labels[perm]

        cap = max_samples * (1 + num_negative_samples)
        if len(targets) > cap:
            targets, contexts, labels = targets[:cap], contexts[:cap], labels[:cap]

        print(f"Total training samples: {len(targets)}")
        print(f"Positive ratio: {np.mean(labels):.2%}")

        return targets, contexts, labels

    def train_model(self, epochs=15, batch_size=4096, quick_mode=False):
        """Train the SGNS model with a tf.data pipeline, then build song embeddings.

        Samples are split 80/20 into train/validation with the instance RNG. The
        run uses early stopping on ``val_loss`` (patience 5, best weights
        restored). ``quick_mode`` is accepted for backward compatibility but is
        currently unused.

        Returns the Keras ``History``, or None when no training data was generated.
        """
        print("Training SGNS model with tf.data...")

        targets, contexts, labels = self.generate_training_data(
            window_size=5,
            num_negative_samples=10,
            max_samples=1000000
        )
        if len(targets) == 0:
            print("Error: No training data generated. Please check your data format.")
            return None

        targets = targets.astype(np.int32)
        contexts = contexts.astype(np.int32)
        labels = labels.astype(np.float32)

        n = len(labels)
        idx = self.rng.permutation(n)
        cut = int(n * 0.8)
        tr, va = idx[:cut], idx[cut:]

        train_ds = tf.data.Dataset.from_tensor_slices(
            ((targets[tr], contexts[tr]), labels[tr])
        ).shuffle(200_000, seed=self.seed, reshuffle_each_iteration=False).batch(batch_size).prefetch(tf.data.AUTOTUNE)

        val_ds = tf.data.Dataset.from_tensor_slices(
            ((targets[va], contexts[va]), labels[va])
        ).batch(batch_size).prefetch(tf.data.AUTOTUNE)

        self.build_sgns_model(embedding_size=128, learning_rate=0.1)

        class VisualizationCallback(keras.callbacks.Callback):
            """Plot train/val loss and their gap after every epoch."""

            def __init__(self):
                super().__init__()  # initialize Keras' own Callback state
                self.losses = []
                self.val_losses = []
                self.epochs_list = []

            def on_epoch_end(self, epoch, logs=None):
                logs = logs or {}
                self.losses.append(logs.get('loss'))
                self.val_losses.append(logs.get('val_loss'))
                self.epochs_list.append(epoch + 1)
                fig, axes = plt.subplots(1, 2, figsize=(10, 4))
                axes[0].plot(self.epochs_list, self.losses, 'b-', label='Train')
                axes[0].plot(self.epochs_list, self.val_losses, 'r-', label='Val')
                axes[0].set_title('Loss'); axes[0].legend(); axes[0].grid(True, alpha=0.3)
                if len(self.losses) > 1:
                    diff = np.array(self.val_losses) - np.array(self.losses)
                    axes[1].plot(self.epochs_list, diff, 'g-')
                    axes[1].axhline(0, color='k', ls='--', alpha=0.5)
                    axes[1].set_title('Val - Train')
                    axes[1].grid(True, alpha=0.3)
                plt.tight_layout(); plt.show()

        viz_callback = VisualizationCallback()
        early_stopping = keras.callbacks.EarlyStopping(
            monitor='val_loss', patience=5, restore_best_weights=True, verbose=1
        )

        history = self.model.fit(
            train_ds,
            validation_data=val_ds,
            epochs=epochs,
            callbacks=[viz_callback, early_stopping],
            verbose=1
        )

        self.plot_final_results(history)
        self.create_song_embeddings()  # also builds the search matrix
        return history

    def plot_final_results(self, history):
        """Plot the loss curves, loss histograms and a text summary of *history*."""
        print("\n" + "="*60)
        print("TRAINING COMPLETED - FINAL RESULTS")
        print("="*60)

        loss = history.history['loss']
        val_loss = history.history['val_loss']

        fig, axes = plt.subplots(1, 3, figsize=(18, 5))

        axes[0].plot(loss, 'b-', label='Training Loss', linewidth=2)
        axes[0].plot(val_loss, 'r-', label='Validation Loss', linewidth=2)
        axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('Loss')
        axes[0].set_title('Final Training History'); axes[0].legend(); axes[0].grid(True, alpha=0.3)

        best = int(np.argmin(val_loss))  # 0-based, matching the plot's x axis
        axes[0].scatter(best, val_loss[best], color='red', s=100, zorder=5)
        axes[0].annotate(f'Best: {val_loss[best]:.6f}',
                         xy=(best, val_loss[best]),
                         xytext=(10, 10), textcoords='offset points',
                         bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.7),
                         arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))

        axes[1].hist(loss, bins=20, alpha=0.5, label='Training', color='blue')
        axes[1].hist(val_loss, bins=20, alpha=0.5, label='Validation', color='red')
        axes[1].set_xlabel('Loss Value'); axes[1].set_ylabel('Frequency')
        axes[1].set_title('Loss Distribution'); axes[1].legend(); axes[1].grid(True, alpha=0.3)

        summary_text = f"""
        Final Training Metrics:

        • Total Epochs: {len(loss)}
        • Final Train Loss: {loss[-1]:.6f}
        • Final Val Loss: {val_loss[-1]:.6f}
        • Best Val Loss: {min(val_loss):.6f}
        • Best Epoch: {best + 1}

        • Initial Loss: {loss[0]:.6f}
        • Loss Reduction: {(1 - loss[-1]/loss[0])*100:.2f}%
        """
        axes[2].text(0.1, 0.5, summary_text, fontsize=11,
                     transform=axes[2].transAxes, verticalalignment='center',
                     bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.3))
        axes[2].axis('off')
        axes[2].set_title('Training Summary')
        plt.suptitle('SGNS Model Training - Final Results', fontsize=16, fontweight='bold')
        plt.tight_layout(); plt.show()

    def create_song_embeddings(self):
        """Average each song's target-embedding tag vectors into a song vector.

        Songs with no in-vocabulary tags get no embedding. Finishes by calling
        ``build_search_matrix``.
        """
        print("Creating song embeddings...")
        word_embeddings = self.model.get_layer('target_embedding').get_weights()[0]

        self.song_embeddings = {}
        for song_id, sentence in tqdm(self.song_sentences.items(), desc="Computing song embeddings"):
            word_indices = [self.word_to_idx[w] for w in sentence if w in self.word_to_idx]
            if word_indices:
                self.song_embeddings[song_id] = np.mean(word_embeddings[word_indices], axis=0)

        self.build_search_matrix()

    # ---------- fast recommendation / evaluation helpers ----------
    def build_search_matrix(self):
        """Build the L2-normalized song matrix ``E_norm`` and its index mappings.

        Rows are ordered by numeric song id. After this call ``idx_to_song`` is a
        list (row -> song id) and ``song_to_idx`` maps song id -> row, covering
        only songs that have an embedding.
        """
        if not self.song_embeddings:
            self.E_norm = None
            return
        self.idx_to_song = sorted(self.song_embeddings.keys(), key=int)
        self.song_to_idx = {sid: i for i, sid in enumerate(self.idx_to_song)}
        E = np.stack([self.song_embeddings[sid] for sid in self.idx_to_song]).astype(np.float32)
        E /= (np.linalg.norm(E, axis=1, keepdims=True) + 1e-12)
        self.E_norm = E  # (N, D)

    def recommend_songs(self, playlist_songs, top_k=100):
        """Return up to ``top_k`` song ids ranked by cosine similarity to the playlist.

        The query is the normalized mean of the seed songs' normalized
        embeddings. Seed songs are never returned, and unknown seeds are
        ignored. Ids are returned as int when numeric, otherwise as str.
        """
        if not playlist_songs or getattr(self, "E_norm", None) is None:
            return []
        seed_idxs = [self.song_to_idx[str(s)] for s in playlist_songs if str(s) in self.song_to_idx]
        if not seed_idxs:
            return []

        q = self.E_norm[seed_idxs].mean(axis=0, dtype=np.float32)
        q /= (np.linalg.norm(q) + 1e-12)

        sims = self.E_norm @ q  # (N,)
        # -inf ranks seeds below every real cosine value. The former -1e-9 is
        # ~0, which outranks every negative similarity, so seeds leaked through.
        sims[seed_idxs] = -np.inf

        # Count unique seeds; duplicates in the playlist occupy only one row.
        K = min(top_k, self.E_norm.shape[0] - len(set(seed_idxs)))
        if K <= 0:
            return []
        top_idx = np.argpartition(-sims, K - 1)[:K]
        top_idx = top_idx[np.argsort(-sims[top_idx])]

        rec_ids = []
        for i in top_idx:
            sid = self.idx_to_song[i]
            rec_ids.append(int(sid) if sid.isdigit() else sid)
        return rec_ids

    def calculate_ndcg(self, recommended, actual, k=100):
        """Binary-relevance nDCG@k of *recommended* against the *actual* songs."""
        actual_set = set(actual)
        dcg = sum(1.0 / np.log2(rank + 2)
                  for rank, song in enumerate(recommended[:k]) if song in actual_set)
        idcg = sum(1.0 / np.log2(rank + 2) for rank in range(min(len(actual), k)))
        if idcg == 0:
            return 0.0
        return dcg / idcg

    # ---------- metric helpers ----------
    def precision_at_k(self, recommended, actual, k=100):
        """Hits among the first k recommendations, divided by k (0.0 for k <= 0)."""
        if k <= 0:
            return 0.0
        hits = len(set(recommended[:k]) & set(actual))
        return hits / float(k)

    def recall_at_k(self, recommended, actual, k=100):
        """Hits among the first k recommendations, divided by len(actual)."""
        if not actual:
            return 0.0
        hits = len(set(recommended[:k]) & set(actual))
        return hits / float(len(actual))

    def hit_rate_at_k(self, recommended, actual, k=100):
        """1.0 if any of the first k recommendations is in *actual*, else 0.0."""
        return 1.0 if set(recommended[:k]) & set(actual) else 0.0

    def evaluate(self, split='val', mask_ratio=0.3, top_k_list=(100, 300, 500), min_len_catalog=30, seed=42):
        """Random-masking evaluation: nDCG / precision / recall / hit rate at each K.

        For every playlist, only songs in the search catalog (``song_to_idx``)
        are kept. Playlists with fewer than ``min_len_catalog`` such songs are
        skipped. ``ceil(len * mask_ratio)`` songs (at least 1) become the answer
        set, and the rest are used as seeds for ``recommend_songs``.

        Args:
            split: 'val', 'test', or an iterable of playlist dicts.
            mask_ratio: fraction of each playlist hidden as answers.
            top_k_list: cut-offs K to report.
            min_len_catalog: minimum number of catalog songs per playlist.
            seed: seed for the masking RNG.

        Returns:
            {K: {"nDCG", "Prec", "Recall", "Hit"}} averaged over the evaluated
            playlists, or {} when nothing could be evaluated.

        Raises:
            ValueError: for an unknown split name.
        """
        print("Evaluating model with random masking...")

        if isinstance(split, str):
            if split == 'test':
                test_playlists = self.test_data
            elif split == 'val':
                test_playlists = self.val_data
            else:
                raise ValueError("split must be 'val', 'test', or a list of playlists.")
            split_label = split
        else:
            test_playlists = list(split)
            # Avoid dumping the whole playlist list into the summary line.
            split_label = f"custom[{len(test_playlists)}]"

        rng = np.random.default_rng(seed)
        kept = set(self.song_to_idx.keys())  # songs that have an embedding

        playlists_catalog = []
        for pl in test_playlists:
            cs = [s for s in pl.get('songs', []) if str(s) in kept]
            if len(cs) >= min_len_catalog:
                playlists_catalog.append(cs)

        if not playlists_catalog:
            print(f"No playlists with >= {min_len_catalog} kept songs.")
            return {}

        print(f"Playlists considered: {len(playlists_catalog)} (min_len_catalog={min_len_catalog})")

        sums = {k: {"ndcg": 0.0, "prec": 0.0, "rec": 0.0, "hit": 0.0} for k in top_k_list}
        evaluated = 0
        max_k = max(top_k_list)

        for songs in tqdm(playlists_catalog, desc="Evaluating (random mask)"):
            m = max(1, int(np.ceil(len(songs) * mask_ratio)))
            idx = np.arange(len(songs))
            rng.shuffle(idx)
            answer_songs = [songs[i] for i in idx[:m]]
            input_songs = [songs[i] for i in idx[m:]]
            if not input_songs:
                continue

            rec = self.recommend_songs(input_songs, top_k=max_k)

            for k in top_k_list:
                sums[k]["ndcg"] += self.calculate_ndcg(rec, answer_songs, k=k)
                sums[k]["prec"] += self.precision_at_k(rec, answer_songs, k=k)
                sums[k]["rec"] += self.recall_at_k(rec, answer_songs, k=k)
                sums[k]["hit"] += self.hit_rate_at_k(rec, answer_songs, k=k)

            evaluated += 1

        if evaluated == 0:
            print("No playlists evaluated.")
            return {}

        results = {
            k: {
                "nDCG": sums[k]["ndcg"] / evaluated,
                "Prec": sums[k]["prec"] / evaluated,
                "Recall": sums[k]["rec"] / evaluated,
                "Hit": sums[k]["hit"] / evaluated,
            }
            for k in top_k_list
        }

        print(f"\n=== Averages (split={split_label}, mask={int(mask_ratio*100)}%, len≥{min_len_catalog}, evaluated={evaluated}) ===")
        for k in top_k_list:
            r = results[k]
            print(f"K={k:>4}  nDCG={r['nDCG']:.4f}  Prec={r['Prec']:.4f}  Recall={r['Recall']:.4f}  HitRate={r['Hit']:.4f}")

        return results

In [None]:
# Main execution cell: load -> preprocess -> train -> export tag/song embeddings.
print("="*60)
print("MUSIC RECO MODEL")
print("Based on Data Embedding (SGNS)")
print("="*60 + "\n")

recommender = MelonMusicRecommender()

print("\n[Step 1/5] Loading Dataset...")
recommender.load_data()

print("\n[Step 2/5] Preprocessing Data...")
recommender.preprocess_data()

print("\n--- Sample of Preprocessed Song Sentences (20 samples) ---")
for sample_count, (song_id, sentence) in enumerate(recommender.song_sentences.items(), start=1):
    print(f"Song ID: {song_id}")
    print(f"Sentence: {sentence}")
    print("-" * 20)
    if sample_count >= 20:
        break
print("------------------------------------------------------------\n")

print("\n[Step 3/5] Training Model...")
history = recommender.train_model(
    epochs=30,
    batch_size=4096,   # 4096-8192 recommended on GPU
    quick_mode=False
)
# train_model returns None when no training data could be generated. In that
# case no model was built, and the export below would fail with an AttributeError.
if history is None:
    raise RuntimeError("Training did not run (no training data); nothing to save.")


print("\n[Step 3-end] Saving Model...")
# 1) Tag (word) embedding matrix from the target-embedding layer, shape (vocab_size, D).
W = recommender.model.get_layer('target_embedding').get_weights()[0]

# 2) Index <-> tag mappings.
word_to_idx = recommender.word_to_idx   # {tag -> idx}
idx_to_word = {i: w for w, i in word_to_idx.items()}

# 3) Store L2-normalized tag vectors keyed by tag.
tag_embeds = {}
for i, w in idx_to_word.items():
    v = W[i].astype(np.float32)
    tag_embeds[w] = v / (np.linalg.norm(v) + 1e-12)


save_path = '/content/drive/MyDrive/Graduation_Project/model_checkpoint/'
os.makedirs(save_path, exist_ok=True)

with open(os.path.join(save_path, 'tag_embeddings.pkl'), 'wb') as f:
    pickle.dump(tag_embeds, f)

with open(os.path.join(save_path, 'word_to_idx.json'), 'w', encoding="utf-8") as f:
    json.dump(word_to_idx, f, ensure_ascii=False)

# NOTE(review): the model is trained with SGD, not Adam; the "adam" in this file
# name is kept unchanged because downstream loaders may depend on it.
with open(os.path.join(save_path, 'song_embeddings_adam_ver4.pkl'), 'wb') as f:
    pickle.dump(recommender.song_embeddings, f)

print("saved tag_embeddings.pkl & word_to_idx.json")