# Preprocessing the LastFM-360K Dataset

This notebook outlines the preprocessing steps used to create a synthetic version of the [LastFM-360K dataset (2010)](http://ocelma.net/MusicRecommendationDataset/lastfm-360K.html/) with three gender classes (m/f/nb) as the sensitive attribute. The non-binary class is created by relabeling a random sample of users. The interaction data is filtered in a single pass (sparse users first, then sparse artists) to keep a dense, well-supported subset.

## Imports and Configuration

The configuration reassigns 10% of users to the non-binary class. The interaction data is filtered to retain users with at least 50 unique artists and artists with at least 20 unique users. Play counts are converted to ratings on a 1–5 scale, which are then binarized at a threshold of 2.5.

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
import os

DATA_DIR = Path(os.getenv('PROJECT_ROOT', Path.cwd()))

NON_BINARY_FRAC = 0.1
RANDOM_SEED = 42
MIN_ARTISTS_PER_USER = 50
MIN_USERS_PER_ARTIST = 20
THRESHOLD_LABEL = 2.5

np.random.seed(RANDOM_SEED)

## Load User Data

User profiles are loaded and reduced to the user ID and gender columns, with gender serving as the sensitive attribute. All other profile columns are discarded, and rows with missing values (in practice, users without a gender label) are dropped.

In [2]:
users_df = pd.read_csv(
    filepath_or_buffer='usersha1-profile.tsv',
    sep='\t',
    header=None,
    usecols=[0, 1],
    names=['user_id', 'gender']
).dropna()
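The interplay of `usecols` and `names` can be checked on a tiny in-memory file. The sketch below assumes a five-column profile layout (user ID, gender, age, country, signup date) and uses made-up rows: only the first two columns are read, the names label those selected columns, and the user with a missing gender is dropped.

```python
import io

import pandas as pd

# Made-up rows in the assumed profile layout:
# user id, gender, age, country, signup date (the third user has no gender)
raw_profile = (
    "u1\tm\t22\tGermany\tFeb 1, 2007\n"
    "u2\tf\t\tSweden\tMar 3, 2008\n"
    "u3\t\t30\tUnited States\tJan 9, 2006\n"
)

toy_users = pd.read_csv(
    io.StringIO(raw_profile),
    sep='\t',
    header=None,
    usecols=[0, 1],               # positional: user id and gender only
    names=['user_id', 'gender'],  # names label the selected columns
).dropna()                        # removes u3 (missing gender)

print(toy_users)
```

A missing age does not remove `u2`, because the age column is never loaded.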

## Create Non-Binary Gender Class

A synthetic non-binary class is created by randomly sampling 10% of all users from the existing male and female populations in proportion to their sizes, so that the original male-to-female ratio is preserved. The gender distribution is printed before and after the reassignment.

In [3]:
male_count = users_df[users_df['gender'] == 'm'].shape[0]
female_count = users_df[users_df['gender'] == 'f'].shape[0]
total_users = male_count + female_count

print("=" * 60)
print("INITIAL GENDER DISTRIBUTION")
print("=" * 60)
print(f"{'Male:':<7} {male_count:<3} ({male_count/total_users*100:>4.1f}%)")
print(f"{'Female:':<8} {female_count:<3} ({female_count/total_users*100:>4.1f}%) \n")

num_non_binary = int(total_users * NON_BINARY_FRAC)

# Sample users to become non-binary (respecting existing gender ratio)
gender_counts = users_df['gender'].value_counts()
ratio_m_f = gender_counts['m'] / gender_counts['f']
num_nb_from_female = int(num_non_binary / (1 + ratio_m_f))
num_nb_from_male = num_non_binary - num_nb_from_female

print("=" * 60)
print("ASSIGNING NON-BINARY GENDERS")
print("=" * 60)
print(f"Sampling {NON_BINARY_FRAC*100:.0f}% of users to be non-binary.")
print(f"Sampling respects the existing M/F ratio ({ratio_m_f:.3f}):")
print(f"  - {num_nb_from_male} from male users")
print(f"  - {num_nb_from_female} from female users \n")

male_indices = users_df[users_df['gender'] == 'm'].sample(
    n=num_nb_from_male, random_state=RANDOM_SEED
).index
female_indices = users_df[users_df['gender'] == 'f'].sample(
    n=num_nb_from_female, random_state=RANDOM_SEED
).index

# Combine and assign non-binary
nb_indices = male_indices.union(female_indices)
users_df.loc[nb_indices, 'gender'] = 'nb'

male_count = users_df[users_df['gender'] == 'm'].shape[0]
female_count = users_df[users_df['gender'] == 'f'].shape[0]
nb_count = users_df[users_df['gender'] == 'nb'].shape[0]

assert total_users == male_count + female_count + nb_count, (
    f"Population mismatch after assigning non-binary genders. "
    f"Before: {total_users}, "
    f"After: {male_count + female_count + nb_count}"
)

print("=" * 60)
print("RESULTING GENDER DISTRIBUTION")
print("=" * 60)
print(f"{'Men:':<11} {male_count:<3} ({male_count/total_users*100:>4.1f}%)")
print(f"{'Women:':<12} {female_count:<3} ({female_count/total_users*100:>4.1f}%)")
print(f"{'Non-binary:':<12} {nb_count:<3} ({nb_count/total_users*100:>4.1f}%) \n")

INITIAL GENDER DISTRIBUTION
Male:   241642 (74.0%)
Female:  84930 (26.0%) 

ASSIGNING NON-BINARY GENDERS
Sampling 10% of users to be non-binary.
Sampling respects the existing M/F ratio (2.845):
  - 24165 from male users
  - 8492 from female users 

RESULTING GENDER DISTRIBUTION
Men:        217477 (66.6%)
Women:       76438 (23.4%)
Non-binary:  32657 (10.0%) 
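The quota arithmetic can be isolated in a small function and checked against the counts printed above. `allocate_non_binary` is a hypothetical helper that repeats the cell's computation: the female share of the quota is rounded down and the male share absorbs the remainder.

```python
def allocate_non_binary(male_count, female_count, frac):
    """Split a non-binary quota between male and female users in
    proportion to their counts (hypothetical helper)."""
    total = male_count + female_count
    num_non_binary = int(total * frac)
    ratio_m_f = male_count / female_count
    # Female share is num_non_binary * f / (m + f), rounded down
    from_female = int(num_non_binary / (1 + ratio_m_f))
    # Male share absorbs the rounding remainder
    from_male = num_non_binary - from_female
    return from_male, from_female


from_male, from_female = allocate_non_binary(241642, 84930, 0.1)
print(from_male, from_female)  # 24165 8492
```

With these counts, the remaining male-to-female ratio (217,477 / 76,438 ≈ 2.845) stays essentially equal to the original ratio of 2.845.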



## Load Interaction Data

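User–item interactions are loaded with only the user ID, artist name, and play count columns. The artist name (column 2) serves as the item identifier, and the MusicBrainz artist ID (column 1) is not used. Rows with missing values are dropped, and interactions are restricted to users present in the cleaned user data so that user IDs are consistent across both tables.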

In [4]:
items_df = pd.read_csv(
    filepath_or_buffer='usersha1-artmbid-artname-plays.tsv',
    sep='\t',
    header=None,
    usecols=[0, 2, 3], 
    names=['user_id', 'item_id', 'plays']
).dropna()

valid_users = set(users_df['user_id'].unique())
items_df = items_df[items_df['user_id'].isin(valid_users)]
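The same pattern applied to the four-column interaction file selects the artist name rather than the MusicBrainz ID. The rows below are invented for illustration. Because the MBID column is never loaded, a missing MBID does not cause a row to be dropped, and the interaction of a user without a profile is removed by the `isin` filter.

```python
import io

import pandas as pd

# Made-up rows in the layout suggested by the file name
# (user id, artist MBID, artist name, plays); one MBID is empty
raw_plays = (
    "u1\tmbid-a\tartist a\t120\n"
    "u1\t\tartist b\t15\n"
    "u9\tmbid-a\tartist a\t40\n"
)

toy_items = pd.read_csv(
    io.StringIO(raw_plays),
    sep='\t',
    header=None,
    usecols=[0, 2, 3],                      # skip the MBID column
    names=['user_id', 'item_id', 'plays'],
).dropna()

# Keep only users that survived profile cleaning (u9 has no profile here)
known_users = {'u1', 'u2'}
toy_items = toy_items[toy_items['user_id'].isin(known_users)]
print(toy_items)
```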

## Interaction Filtering

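The interaction data is filtered to retain a dense, well-supported subset. Users with fewer than 50 unique artists are removed first; artists listened to by fewer than 20 of the remaining users are then discarded. Summary statistics are reported before and after filtering. Because the two steps are applied only once, removing artists can push some users back below the 50-artist threshold, which is why the minimum number of artists per user after filtering is 6 rather than 50.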

In [5]:
users_before_filtering = set(items_df['user_id'].unique())
artists_per_user_before = items_df.groupby('user_id')['item_id'].nunique()
users_per_artist_before = items_df.groupby('item_id')['user_id'].nunique()

print("=" * 60)
print("BEFORE FILTERING")
print("=" * 60)
print(f"Total interactions: {len(items_df):,}")
print(f"Min artists per user: {artists_per_user_before.min()}")
print(f"Min users per artist: {users_per_artist_before.min()}\n")

active_users = artists_per_user_before[artists_per_user_before >= MIN_ARTISTS_PER_USER].index
items_df = items_df[items_df['user_id'].isin(active_users)]

# Recalculate artist statistics after user filtering
users_per_artist_filtered = items_df.groupby('item_id')['user_id'].nunique()

popular_artists = users_per_artist_filtered[users_per_artist_filtered >= MIN_USERS_PER_ARTIST].index
items_df = items_df[items_df['item_id'].isin(popular_artists)]

users_after_filtering = set(items_df['user_id'].unique())
artists_per_user_after = items_df.groupby('user_id')['item_id'].nunique()
users_per_artist_after = items_df.groupby('item_id')['user_id'].nunique()

print("=" * 60)
print("AFTER FILTERING")
print("=" * 60)
print(f"Total interactions: {len(items_df):,}")
print(f"Min artists per user: {artists_per_user_after.min()}")
print(f"Min users per artist: {users_per_artist_after.min()}\n")

BEFORE FILTERING
Total interactions: 15,947,926
Min artists per user: 1
Min users per artist: 1

AFTER FILTERING
Total interactions: 6,792,309
Min artists per user: 6
Min users per artist: 20
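A toy example shows how the minimum number of artists per user can fall below `MIN_ARTISTS_PER_USER` after a single pass. `single_pass_filter` is a hypothetical helper that repeats the two steps of the cell above on made-up data, with thresholds of 3 items per user and 2 users per item.

```python
import pandas as pd


def single_pass_filter(df, min_items_per_user, min_users_per_item):
    """Drop sparse users, then sparse items, exactly once (no iteration)."""
    items_per_user = df.groupby('user_id')['item_id'].nunique()
    active = items_per_user[items_per_user >= min_items_per_user].index
    df = df[df['user_id'].isin(active)]

    # Item popularity is recomputed on the remaining users only
    users_per_item = df.groupby('item_id')['user_id'].nunique()
    popular = users_per_item[users_per_item >= min_users_per_item].index
    return df[df['item_id'].isin(popular)]


# u1 and u2 interact with 3 items each, u3 with only 1
toy = pd.DataFrame({
    'user_id': ['u1', 'u1', 'u1', 'u2', 'u2', 'u2', 'u3'],
    'item_id': ['a', 'b', 'c', 'a', 'b', 'd', 'a'],
})

filtered = single_pass_filter(toy, min_items_per_user=3, min_users_per_item=2)
counts = filtered.groupby('user_id')['item_id'].nunique()
print(counts.to_dict())  # {'u1': 2, 'u2': 2}
```

Both users pass the first step with 3 items each, but items `c` and `d` fail the second step, leaving each user with 2. Repeating both steps until nothing more is removed (an iterative, k-core-style filter) would enforce both thresholds at once; the notebook applies them a single time.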



## Update Users DataFrame

User data is synchronized with the filtered interaction data by removing users that were eliminated during interaction filtering.

In [6]:
removed_users = users_before_filtering.difference(users_after_filtering)

# Filtering can only shrink the user set; new users would indicate a bug upstream
if not users_after_filtering.issubset(users_before_filtering):
    raise ValueError("Unexpected condition: users present after filtering that were absent before.")

if removed_users:
    print(f"{len(removed_users)} users removed during filtering. Updating users dataframe... \n")
    users_df = users_df[users_df['user_id'].isin(users_after_filtering)].reset_index(drop=True)
else:
    print("No users removed during filtering. Proceeding without updating users dataframe. \n")

194156 users removed during filtering. Updating users dataframe... 



## Label Construction

Play counts are log-transformed with `log1p` to reduce skew and then min-max rescaled to the range [1, 5], i.e. `rating = 1 + 4 * (log1p(plays) - min_log) / (max_log - min_log)`. Ratings of at least 2.5 receive label 1 and all others label 0. Intermediate columns are then removed and the final label distribution is reported.

In [7]:
items_df['log_plays'] = np.log1p(items_df['plays'])

# Min-max normalize to [1,5]
min_log = items_df['log_plays'].min()
max_log = items_df['log_plays'].max()
items_df['rating'] = 1 + (items_df['log_plays'] - min_log) * 4 / (max_log - min_log)

items_df['label'] = (items_df['rating'] >= THRESHOLD_LABEL).astype(int)

items_df = items_df.drop(columns=['plays', 'log_plays', 'rating'])

total_labels = len(items_df)
num_label_1 = (items_df['label'] == 1).sum()
num_label_0 = (items_df['label'] == 0).sum()
pct_label_1 = num_label_1 / total_labels * 100
pct_label_0 = num_label_0 / total_labels * 100

print("=" * 60)
print("LABEL DISTRIBUTION")
print("=" * 60)
print(f"Label 1: {num_label_1:,} ({pct_label_1:.1f}%)")
print(f"Label 0: {num_label_0:,} ({pct_label_0:.1f}%) \n")

LABEL DISTRIBUTION
Label 1: 2,665,018 (39.2%)
Label 0: 4,127,291 (60.8%) 
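Because the transform is monotone, the 2.5 rating threshold corresponds to a fixed play-count cut-off: `rating >= 2.5` holds exactly when `log1p(plays) >= min_log + 0.375 * (max_log - min_log)`. The sketch below uses invented play counts, so its cut-off (about 19.6 plays) is illustrative only; the actual cut-off depends on the minimum and maximum play counts in the filtered data.

```python
import numpy as np
import pandas as pd

THRESHOLD_LABEL = 2.5

# Made-up play counts; the notebook uses the filtered interactions instead
plays = pd.Series([1, 3, 10, 50, 1000])

log_plays = np.log1p(plays)
min_log, max_log = log_plays.min(), log_plays.max()
rating = 1 + (log_plays - min_log) * 4 / (max_log - min_log)  # min-max rescale to [1, 5]
label = (rating >= THRESHOLD_LABEL).astype(int)

# rating >= t  <=>  log1p(plays) >= min_log + (t - 1) / 4 * (max_log - min_log)
frac = (THRESHOLD_LABEL - 1) / 4  # 0.375 for t = 2.5
min_plays = np.expm1(min_log + frac * (max_log - min_log))

print(label.tolist())      # [0, 0, 0, 1, 1]
print(f"{min_plays:.2f}")  # equivalent play-count cut-off for this toy data
```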



## ID Remapping and Consistency Checks

User and item identifiers are remapped to consecutive integers starting at 0 to ensure consistent indexing and model compatibility. User IDs are numbered in the order users appear in the user data; artists, originally identified by name strings, are numbered in order of first appearance in the interaction data. Consistency checks then verify that the IDs are contiguous and that every user ID in the interaction data exists in the user data.

In [8]:
user_ids = users_df['user_id'].unique()
user_id_map = {old_id: new_id for new_id, old_id in enumerate(user_ids)}
users_df['user_id'] = users_df['user_id'].map(user_id_map)
items_df['user_id'] = items_df['user_id'].map(user_id_map)

artist_ids = items_df['item_id'].unique()
artist_id_map = {old_id: new_id for new_id, old_id in enumerate(artist_ids)}
items_df['item_id'] = items_df['item_id'].map(artist_id_map)

print("=" * 60)
print("RANGE CONSISTENCY CHECK")
print("=" * 60)
print("Users ID range:", users_df['user_id'].min(), "-", users_df['user_id'].max())
print("Interaction user ID range:", items_df['user_id'].min(), "-", items_df['user_id'].max())
print("Users consistent with items:", items_df['user_id'].isin(users_df['user_id']).all(), "\n")

print("Item ID range:", items_df['item_id'].min(), "-", items_df['item_id'].max())
print("Item IDs are consecutive:", items_df['item_id'].max() == items_df['item_id'].nunique() - 1)

RANGE CONSISTENCY CHECK
Users ID range: 0 - 131986
Interaction user ID range: 0 - 131986
Users consistent with items: True 

Item ID range: 0 - 26678
Item IDs are consecutive: True
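The dictionary-based remapping numbers IDs in order of first appearance, which is the same numbering `pd.factorize` produces. The toy frame below uses made-up IDs to show that the two approaches agree. Factorizing `users_df['user_id']` would reproduce `user_id_map` in the same way, although the dictionary remains convenient for translating the user IDs in `items_df`.

```python
import pandas as pd

# Made-up raw identifiers (hash-like user ids, artist names as items)
toy = pd.DataFrame({
    'user_id': ['sha_b', 'sha_a', 'sha_b'],
    'item_id': ['radiohead', 'bjork', 'bjork'],
})

# Dictionary-based remapping, as in the cell above
user_map = {old: new for new, old in enumerate(toy['user_id'].unique())}
mapped_users = toy['user_id'].map(user_map)

# pd.factorize numbers values in the same first-appearance order
user_codes, user_uniques = pd.factorize(toy['user_id'])
item_codes, item_uniques = pd.factorize(toy['item_id'])

print(mapped_users.tolist(), user_codes.tolist(), item_codes.tolist())
# [0, 1, 0] [0, 1, 0] [0, 1, 1]
```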


## Train/Val/Test Split

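Interactions are split per user (80% train, 10% validation, 10% test), so each user's interactions are divided across the splits rather than whole users being assigned to a single split. Because the train and validation counts are rounded down, the remainder goes to test: every user appears in both train and test, users with fewer than 10 interactions have no validation interactions, and the test split ends up larger than the validation split.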

In [9]:
def split_per_user(df, train_frac=0.8, val_frac=0.1, random_state=42):
    """
    Split each user's interactions into train/validation/test sets using groupby and apply.

    Args:
        df: DataFrame with 'user_id', 'item_id' and 'label' columns
        train_frac: Fraction of each user's interactions used for training
        val_frac: Fraction of each user's interactions used for validation
        random_state: Random seed

    Returns:
        train_df, val_df, test_df
    """
    # One generator shared by all users. Passing the same integer seed to every
    # sample() call would shuffle all users with the same number of
    # interactions in exactly the same way.
    rng = np.random.default_rng(random_state)

    def split_user_data(user_items):
        """
        Split a single user's interactions into train/val/test sets.
        """
        user_id = user_items.name
        user_items = user_items.copy()
        user_items['user_id'] = user_id  # Restore the grouping key

        user_items = user_items.sample(frac=1, random_state=rng)  # Shuffle items
        num_items = len(user_items)

        # Train/val counts are rounded down; the remainder goes to test
        num_train = int(train_frac * num_items)
        num_val = int(val_frac * num_items)

        split_labels = ['train'] * num_train + ['val'] * num_val + ['test'] * (num_items - num_train - num_val)
        user_items['split'] = split_labels

        return user_items

    # Apply splitting to each user's items
    df_with_splits = df.groupby('user_id', group_keys=False).apply(split_user_data)
    df_with_splits = df_with_splits[['user_id', 'item_id', 'label', 'split']]  # Select and order columns

    train_df = df_with_splits[df_with_splits['split'] == 'train'].drop(columns=['split'])
    valid_df = df_with_splits[df_with_splits['split'] == 'val'].drop(columns=['split'])
    test_df = df_with_splits[df_with_splits['split'] == 'test'].drop(columns=['split'])

    return (
        train_df.reset_index(drop=True),
        valid_df.reset_index(drop=True),
        test_df.reset_index(drop=True)
    )

train_df, valid_df, test_df = split_per_user(items_df, random_state=RANDOM_SEED)

print("=" * 60)
print("DATASET SPLITS")
print("=" * 60)
print(f"{'Train size:':<11} {len(train_df):>5,}")
print(f"{'Valid size:':<13} {len(valid_df):>5,}")
print(f"{'Test size:':<13} {len(test_df):>5,}")
print(f"{'Total size:':<11} {len(train_df) + len(valid_df) + len(test_df):>5,} \n")

DATASET SPLITS
Train size: 5,386,464
Valid size:   618,118
Test size:    787,727
Total size: 6,792,309 
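Since split sizes depend only on each user's number of interactions, the same floor-based split can be computed without calling a Python function once per user. `split_per_user_vectorized` is a hypothetical alternative, not the notebook's method: it shuffles all rows once, ranks rows within each user with `cumcount`, and compares that rank against the per-user train and validation counts. It reproduces the split sizes of `split_per_user`, but not the same row assignment.

```python
import numpy as np
import pandas as pd


def split_per_user_vectorized(df, train_frac=0.8, val_frac=0.1, random_state=42):
    """Hypothetical vectorized per-user split with floor-based sizes."""
    # Shuffle all rows once; row order within each user is then random
    shuffled = df.sample(frac=1, random_state=random_state)

    # Position of each row within its user (0 .. n-1) and the user's row count
    rank = shuffled.groupby('user_id').cumcount()
    n = shuffled['user_id'].map(shuffled['user_id'].value_counts())

    # Same rounding as split_per_user: floor train/val, remainder to test
    num_train = (train_frac * n).astype(int)
    num_val = (val_frac * n).astype(int)

    split = np.where(rank < num_train, 'train',
                     np.where(rank < num_train + num_val, 'val', 'test'))
    return tuple(shuffled[split == name].reset_index(drop=True)
                 for name in ('train', 'val', 'test'))


# Made-up data: user 0 has 10 interactions, user 1 has 6
toy = pd.DataFrame({
    'user_id': [0] * 10 + [1] * 6,
    'item_id': list(range(10)) + list(range(6)),
    'label': [1, 0] * 8,
})
train, val, test = split_per_user_vectorized(toy, random_state=0)
print(len(train), len(val), len(test))  # 12 1 3
```

The 6-interaction user gets 4 train rows, no validation rows, and 2 test rows, matching the rounding in `split_per_user`.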



## Summary and Save Output Files

Final dataset statistics are displayed, gender labels are mapped to integers (m=0, f=1, nb=2), and all processed files are saved as CSV to `DATA_DIR`: the sensitive attribute table, both ordered by user ID and in randomly shuffled row order, and the train, validation, and test splits.

In [10]:
print("=" * 60)
print("FINAL DATASET SUMMARY")
print("=" * 60)
print(f"Total users: {len(users_df)}")
print(f"Total items: {items_df['item_id'].nunique()}")
print(f"Total interactions: {len(train_df) + len(valid_df) + len(test_df):,}")

male_final = (users_df['gender'] == 'm').sum()
female_final = (users_df['gender'] == 'f').sum()
nb_final = (users_df['gender'] == 'nb').sum()

print(f"\nFinal gender distribution:")
print(f" - {'Male:':<11} {male_final:>4} ({male_final/len(users_df)*100:.1f}%)")
print(f" - {'Female:':<11} {female_final:>4} ({female_final/len(users_df)*100:.1f}%)")
print(f" - {'Non-binary:':<11} {nb_final:>4} ({nb_final/len(users_df)*100:.1f}%) \n")

print("Mapping gender labels to integers...")
gender_mapping = {'m': 0, 'f': 1, 'nb': 2}
users_df['gender'] = users_df['gender'].map(gender_mapping)

# Randomized sensitive attribute dataset
users_random = users_df.sample(frac=1, random_state=RANDOM_SEED).reset_index(drop=True)

print(f"\nSaving processed files to: {DATA_DIR.parent.parent.name}/{DATA_DIR.parent.name}/{DATA_DIR.name} \n")

users_df.to_csv(DATA_DIR / 'sensitive_attribute.csv', index=False)
users_random.to_csv(DATA_DIR / 'sensitive_attribute_random.csv', index=False)
train_df.to_csv(DATA_DIR / 'train.csv', index=False)
valid_df.to_csv(DATA_DIR / 'valid.csv', index=False)
test_df.to_csv(DATA_DIR / 'test.csv', index=False)

print("✓ All files saved successfully!")

FINAL DATASET SUMMARY
Total users: 131987
Total items: 26679
Total interactions: 6,792,309

Final gender distribution:
 - Male:       91567 (69.4%)
 - Female:     27208 (20.6%)
 - Non-binary: 13212 (10.0%) 

Mapping gender labels to integers...

Saving processed files to: Three-Class-MPR/datasets/Lastfm-360K-synthetic 

✓ All files saved successfully!
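The randomized sensitive attribute file is a row permutation of the ordered one, so sorting it by `user_id` recovers the ordered table. A toy check with made-up users follows; note that `Series.map` turns any label missing from `gender_mapping` into NaN, so a `notna()` check after mapping catches unexpected gender values.

```python
import pandas as pd

gender_mapping = {'m': 0, 'f': 1, 'nb': 2}

# Made-up users table, already sorted by remapped user_id as in the notebook
users = pd.DataFrame({'user_id': [0, 1, 2, 3], 'gender': ['m', 'nb', 'f', 'm']})
users['gender'] = users['gender'].map(gender_mapping)

# Shuffled copy: row order changes, but each user keeps its gender
users_random = users.sample(frac=1, random_state=42).reset_index(drop=True)

restored = users_random.sort_values('user_id').reset_index(drop=True)
print(restored.equals(users))  # True
```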
