In [None]:
import os
import glob
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from irt import Beta3

# ------------------ 1) Configuration ------------------
# Genres to run the pipeline on; each one is processed independently.
genres = [
    "action",
    "adventure",
    "animation",
    "comedy",
    "crime",
    "documentary",
    "drama",
    "family",
    "fantasy",
    "history",
    "horror",
    "music",
    "mystery",
    "romance",
    "science_fiction",
    "thriller",
    "war",
    "western",
]

# Root folder of the input CSVs (searched one subfolder deep)
base_dir = "results_top_n_genre"

# Output locations: one folder for plots, one for the merged per-user CSVs
output_plot_dir = "plots"
output_csv_dir = "genre_user_info"

# Make sure both output folders exist before anything is written to them
for _out_dir in (output_plot_dir, output_csv_dir):
    os.makedirs(_out_dir, exist_ok=True)

# ------------------ 2) Function Definitions ------------------
def compute_sts_matrix_from_csv(csv_path, scale=10):
    """
    Reads a CSV file and computes one STS value per row.

    The CSV must contain the columns:
      - user_id
      - ndcg_correct
      - ndcg_flipped

    STS is computed as:
      STS = 1 - scale * abs(ndcg_correct - ndcg_flipped)

    The result is not clipped, so any row whose absolute NDCG difference
    exceeds 1 / scale yields a negative STS.

    Args:
        csv_path: Path of the CSV file to read.
        scale: Multiplier applied to the absolute NDCG difference before it
            is subtracted from 1. Defaults to 10.

    Returns:
        A tuple (sts_array, user_ids_array):
          - sts_array: 1-D float32 NumPy array with one STS value per CSV row.
          - user_ids_array: the "user_id" column as a pandas Series, in the
            same row order as sts_array.

    Raises:
        KeyError: If any of the required columns is missing from the CSV.
    """
    df = pd.read_csv(csv_path)

    # Fail up front with one message naming every missing column, instead of
    # a bare KeyError on whichever column happens to be accessed first.
    required_columns = ("user_id", "ndcg_correct", "ndcg_flipped")
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        raise KeyError(f"{csv_path} is missing required column(s): {missing}")

    # Convert NDCG scores to float32 torch tensors
    ndcg_correct = torch.tensor(df["ndcg_correct"].values, dtype=torch.float32)
    ndcg_flipped = torch.tensor(df["ndcg_flipped"].values, dtype=torch.float32)

    user_ids_array = df["user_id"]

    # Compute STS: 1 - scale * abs(ndcg_correct - ndcg_flipped)
    sts_array = 1 - scale * torch.abs(ndcg_correct - ndcg_flipped)

    # STS goes back as a NumPy array; the user IDs stay a pandas Series
    return sts_array.numpy(), user_ids_array


def loss_function(b4, df_matrix):
    """
    Mean absolute error between the ICC prediction and the observed matrix.

    For every cell (row, col) of df_matrix the prediction is
    ICC_function(b4.abilities[col], b4.difficulties[row],
    b4.discriminations[row]); rows are therefore the items/users and
    columns the respondents/models.

    Args:
        b4: Fitted model exposing indexable `abilities` (per column) and
            `difficulties` / `discriminations` (per row).
        df_matrix: pandas DataFrame of observed STS values, read
            positionally through `.iloc`.

    Returns:
        np.mean of the absolute differences over all cells.
    """
    n_items, n_respondents = df_matrix.shape

    # Row-major walk over the matrix: one absolute error per cell
    abs_errors = [
        abs(
            ICC_function(
                b4.abilities[col],        # ability of respondent/model col
                b4.difficulties[row],     # difficulty of item/user row
                b4.discriminations[row],  # discrimination of item/user row
            )
            - df_matrix.iloc[row, col]
        )
        for row in range(n_items)
        for col in range(n_respondents)
    ]
    return np.mean(abs_errors)


def ICC_function(abilities, difficulties, discriminations):
    """
    Item characteristic curve (ICC) evaluated from Beta3-style parameters.

    Computes
        1 / (1 + (((1 - ability) / ability)
                  * (difficulty / (1 - difficulty))) ** discrimination)

    No guarding is done here: an ability of 0 or a difficulty of 1 makes
    one of the divisions hit a zero denominator.
    """
    inverse_ability_odds = (1 - abilities) / abilities
    difficulty_odds = difficulties / (1 - difficulties)
    powered_ratio = (inverse_ability_odds * difficulty_odds) ** discriminations
    return 1 / (powered_ratio + 1)


# ------------------ 3) Process Each Genre ------------------
# Imported once here instead of being re-executed on every loop iteration.
import matplotlib.patches as mpatches

# The user metadata does not depend on the genre, so it is loaded and encoded
# once up front. This also surfaces a missing file before any model is fitted.
user_info = pd.read_csv("user_info_existing.csv")  # Ensure this file contains `user_id, gender`
gender_map = {"M": 0, "F": 1}
# Values not present in gender_map become NaN
user_info["gender"] = user_info["gender"].map(gender_map)

for target_genre in genres:
    print(f"\nProcessing genre: {target_genre}")

    # ------------------ 4) Build the File Pattern ------------------
    pattern = os.path.join(base_dir, "*", f"per_user_ndcg_scores_*_{target_genre.lower()}.csv")
    # glob returns matches in arbitrary order; sorting keeps the column order
    # of the STS matrix (one column per CSV) reproducible across runs.
    csv_files = sorted(glob.glob(pattern))

    if not csv_files:
        print(f"⚠️ No CSV files found for genre '{target_genre}', skipping...")
        continue

    print(f"Found {len(csv_files)} CSV files for genre '{target_genre}'.")

    # ------------------ 5) Compute STS from CSV ------------------
    list_of_sts_arrays = []
    # None means "no CSV read yet"; the first CSV sets the reference user IDs.
    # (An empty pd.Series() sentinel emits a default-dtype warning.)
    user_ids_array = None

    for csv_file in csv_files:
        sts_array, current_user_ids = compute_sts_matrix_from_csv(csv_file)

        if user_ids_array is None:
            user_ids_array = current_user_ids
        elif not user_ids_array.equals(current_user_ids):
            # Every CSV must list the same users in the same order, because
            # the STS arrays are stacked positionally below.
            raise ValueError(f"User ID mismatch in {csv_file}")

        list_of_sts_arrays.append(sts_array)
        print(f"Processed {csv_file}, STS array shape: {sts_array.shape}")

    # ------------------ 6) Combine STS Arrays into One Matrix ------------------
    final_matrix = np.vstack(list_of_sts_arrays).T  # shape: (num_users, num_models)
    print(f"Final STS matrix shape for '{target_genre}': {final_matrix.shape}")

    normalized_df = pd.DataFrame(final_matrix)

    # ------------------ 7) Beta3 IRT Pipeline ------------------
    subjects = normalized_df.shape[1]  # number of models/respondents
    items = normalized_df.shape[0]  # number of users/items

    # Initialize and run Beta3
    b4 = Beta3(
        learning_rate=10,
        epochs=5000,
        n_respondents=subjects,
        n_items=items,
        n_workers=-1,
        random_seed=1,
    )

    print(f"Fitting Beta3 model for '{target_genre}'...")
    b4.fit(normalized_df.values)
    print(f"Model fitting complete for '{target_genre}'.")

    loss = loss_function(b4, normalized_df)
    print(f"Final loss for '{target_genre}': {loss}")

    # ------------------ 8) Merge User Data ------------------
    user_ids_ordered = user_ids_array.tolist()  # Extract user IDs in order

    user_beta3_results = pd.DataFrame({
        "user_id": user_ids_ordered,
        "discrimination": b4.discriminations,
        "difficulty": b4.difficulties
    })

    # Left merge keeps every fitted user even when user_info has no row for it
    merged_df = user_beta3_results.merge(user_info, on="user_id", how="left")

    # ------------------ 9) Plot Results ------------------
    disc_values = np.array(merged_df["discrimination"])
    difficulty_values = np.array(merged_df["difficulty"])
    user_genders = np.array(merged_df["gender"])
    # Male = Red, Female = Blue. Anything else (missing or unmapped gender)
    # is drawn gray rather than being silently counted as Female.
    colors = np.where(
        user_genders == 0,
        "red",
        np.where(user_genders == 1, "blue", "gray"),
    )

    plt.figure(figsize=(12, 7))
    plt.scatter(disc_values, difficulty_values, c=colors, alpha=0.7)
    plt.title(f"Discrimination vs. Difficulty ({target_genre.capitalize()}, Colored by Gender)")
    plt.xlabel("Discrimination")
    plt.ylabel("Difficulty")
    plt.grid(alpha=0.3)

    legend_handles = [
        mpatches.Patch(color="red", label="Male"),
        mpatches.Patch(color="blue", label="Female"),
    ]
    # Only show the extra legend entry when such points are actually plotted
    if (colors == "gray").any():
        legend_handles.append(mpatches.Patch(color="gray", label="Unknown"))
    plt.legend(handles=legend_handles)

    output_plot_path = os.path.join(output_plot_dir, f"discrimination_difficulty_plot_{target_genre}.png")
    plt.savefig(output_plot_path, dpi=300, bbox_inches="tight")
    plt.close()

    # ------------------ 10) Save Merged Data ------------------
    output_csv_path = os.path.join(output_csv_dir, f"genre_user_info_merged_{target_genre}.csv")
    merged_df.to_csv(output_csv_path, index=False)

    print(f"✅ Processed '{target_genre}': Plot saved at {output_plot_path}, Data saved at {output_csv_path}")

print("\n🎉 All genres processed successfully!")



Processing genre: action
Found 5 CSV files for genre 'action'.
Processed results_top_n_genre\CKE\per_user_ndcg_scores_CKE_action.csv, STS array shape: (848,)
Processed results_top_n_genre\KGAT\per_user_ndcg_scores_KGAT_action.csv, STS array shape: (848,)
Processed results_top_n_genre\KGCN\per_user_ndcg_scores_KGCN_action.csv, STS array shape: (848,)
Processed results_top_n_genre\KGIN\per_user_ndcg_scores_KGIN_action.csv, STS array shape: (848,)


  user_ids_array = pd.Series()


Processed results_top_n_genre\NCFKG\per_user_ndcg_scores_NCFKG_action.csv, STS array shape: (848,)
Final STS matrix shape for 'action': (848, 5)
Fitting Beta3 model for 'action'...
Model fitting complete for 'action'.
Final loss for 'action': 0.06820740552152342
✅ Processed 'action': Plot saved at plots\discrimination_difficulty_plot_action.png, Data saved at genre_user_info\genre_user_info_merged_action.csv

Processing genre: adventure
Found 5 CSV files for genre 'adventure'.
Processed results_top_n_genre\CKE\per_user_ndcg_scores_CKE_adventure.csv, STS array shape: (668,)
Processed results_top_n_genre\KGAT\per_user_ndcg_scores_KGAT_adventure.csv, STS array shape: (668,)
Processed results_top_n_genre\KGCN\per_user_ndcg_scores_KGCN_adventure.csv, STS array shape: (668,)
Processed results_top_n_genre\KGIN\per_user_ndcg_scores_KGIN_adventure.csv, STS array shape: (668,)
Processed results_top_n_genre\NCFKG\per_user_ndcg_scores_NCFKG_adventure.csv, STS array shape: (668,)
Final STS matrix s

  user_ids_array = pd.Series()


Model fitting complete for 'adventure'.
Final loss for 'adventure': 0.0449840162142009
✅ Processed 'adventure': Plot saved at plots\discrimination_difficulty_plot_adventure.png, Data saved at genre_user_info\genre_user_info_merged_adventure.csv

Processing genre: animation
Found 5 CSV files for genre 'animation'.
Processed results_top_n_genre\CKE\per_user_ndcg_scores_CKE_animation.csv, STS array shape: (349,)
Processed results_top_n_genre\KGAT\per_user_ndcg_scores_KGAT_animation.csv, STS array shape: (349,)
Processed results_top_n_genre\KGCN\per_user_ndcg_scores_KGCN_animation.csv, STS array shape: (349,)
Processed results_top_n_genre\KGIN\per_user_ndcg_scores_KGIN_animation.csv, STS array shape: (349,)
Processed results_top_n_genre\NCFKG\per_user_ndcg_scores_NCFKG_animation.csv, STS array shape: (349,)
Final STS matrix shape for 'animation': (349, 5)
Fitting Beta3 model for 'animation'...


  user_ids_array = pd.Series()


Model fitting complete for 'animation'.
Final loss for 'animation': 0.02964843453668981
✅ Processed 'animation': Plot saved at plots\discrimination_difficulty_plot_animation.png, Data saved at genre_user_info\genre_user_info_merged_animation.csv

Processing genre: comedy
Found 5 CSV files for genre 'comedy'.
Processed results_top_n_genre\CKE\per_user_ndcg_scores_CKE_comedy.csv, STS array shape: (838,)
Processed results_top_n_genre\KGAT\per_user_ndcg_scores_KGAT_comedy.csv, STS array shape: (838,)
Processed results_top_n_genre\KGCN\per_user_ndcg_scores_KGCN_comedy.csv, STS array shape: (838,)
Processed results_top_n_genre\KGIN\per_user_ndcg_scores_KGIN_comedy.csv, STS array shape: (838,)
Processed results_top_n_genre\NCFKG\per_user_ndcg_scores_NCFKG_comedy.csv, STS array shape: (838,)
Final STS matrix shape for 'comedy': (838, 5)
Fitting Beta3 model for 'comedy'...


  user_ids_array = pd.Series()


Model fitting complete for 'comedy'.
Final loss for 'comedy': 0.06953154167612305
✅ Processed 'comedy': Plot saved at plots\discrimination_difficulty_plot_comedy.png, Data saved at genre_user_info\genre_user_info_merged_comedy.csv

Processing genre: crime
Found 5 CSV files for genre 'crime'.
Processed results_top_n_genre\CKE\per_user_ndcg_scores_CKE_crime.csv, STS array shape: (638,)
Processed results_top_n_genre\KGAT\per_user_ndcg_scores_KGAT_crime.csv, STS array shape: (638,)
Processed results_top_n_genre\KGCN\per_user_ndcg_scores_KGCN_crime.csv, STS array shape: (638,)
Processed results_top_n_genre\KGIN\per_user_ndcg_scores_KGIN_crime.csv, STS array shape: (638,)
Processed results_top_n_genre\NCFKG\per_user_ndcg_scores_NCFKG_crime.csv, STS array shape: (638,)
Final STS matrix shape for 'crime': (638, 5)
Fitting Beta3 model for 'crime'...


  user_ids_array = pd.Series()


Model fitting complete for 'crime'.
Final loss for 'crime': 0.0424195473273334
✅ Processed 'crime': Plot saved at plots\discrimination_difficulty_plot_crime.png, Data saved at genre_user_info\genre_user_info_merged_crime.csv

Processing genre: documentary
Found 5 CSV files for genre 'documentary'.
Processed results_top_n_genre\CKE\per_user_ndcg_scores_CKE_documentary.csv, STS array shape: (103,)
Processed results_top_n_genre\KGAT\per_user_ndcg_scores_KGAT_documentary.csv, STS array shape: (103,)
Processed results_top_n_genre\KGCN\per_user_ndcg_scores_KGCN_documentary.csv, STS array shape: (103,)
Processed results_top_n_genre\KGIN\per_user_ndcg_scores_KGIN_documentary.csv, STS array shape: (103,)
Processed results_top_n_genre\NCFKG\per_user_ndcg_scores_NCFKG_documentary.csv, STS array shape: (103,)
Final STS matrix shape for 'documentary': (103, 5)
Fitting Beta3 model for 'documentary'...


  user_ids_array = pd.Series()


Model fitting complete for 'documentary'.
Final loss for 'documentary': 0.002774931639801956
✅ Processed 'documentary': Plot saved at plots\discrimination_difficulty_plot_documentary.png, Data saved at genre_user_info\genre_user_info_merged_documentary.csv

Processing genre: drama
Found 5 CSV files for genre 'drama'.
Processed results_top_n_genre\CKE\per_user_ndcg_scores_CKE_drama.csv, STS array shape: (905,)
Processed results_top_n_genre\KGAT\per_user_ndcg_scores_KGAT_drama.csv, STS array shape: (905,)
Processed results_top_n_genre\KGCN\per_user_ndcg_scores_KGCN_drama.csv, STS array shape: (905,)
Processed results_top_n_genre\KGIN\per_user_ndcg_scores_KGIN_drama.csv, STS array shape: (905,)
Processed results_top_n_genre\NCFKG\per_user_ndcg_scores_NCFKG_drama.csv, STS array shape: (905,)
Final STS matrix shape for 'drama': (905, 5)
Fitting Beta3 model for 'drama'...


  user_ids_array = pd.Series()


Model fitting complete for 'drama'.
Final loss for 'drama': 0.0717567388201643
✅ Processed 'drama': Plot saved at plots\discrimination_difficulty_plot_drama.png, Data saved at genre_user_info\genre_user_info_merged_drama.csv

Processing genre: family
⚠️ No CSV files found for genre 'family', skipping...

Processing genre: fantasy
Found 5 CSV files for genre 'fantasy'.
Processed results_top_n_genre\CKE\per_user_ndcg_scores_CKE_fantasy.csv, STS array shape: (180,)
Processed results_top_n_genre\KGAT\per_user_ndcg_scores_KGAT_fantasy.csv, STS array shape: (180,)
Processed results_top_n_genre\KGCN\per_user_ndcg_scores_KGCN_fantasy.csv, STS array shape: (180,)
Processed results_top_n_genre\KGIN\per_user_ndcg_scores_KGIN_fantasy.csv, STS array shape: (180,)
Processed results_top_n_genre\NCFKG\per_user_ndcg_scores_NCFKG_fantasy.csv, STS array shape: (180,)
Final STS matrix shape for 'fantasy': (180, 5)
Fitting Beta3 model for 'fantasy'...


  user_ids_array = pd.Series()
