# Step 1: Import Necessary Libraries

In [None]:
# We'll use `pandas` to create and manipulate the DataFrame, and `json` to save the final data in JSON format.
import pandas as pd
import json

# Step 2: Load and parse the file

In [None]:
file_path = 'msd_genre_classification.cls'  # Replace with your actual file path

# Parse the tab-separated file into one dictionary per track.
# Line layout consumed below:
#   track_id <TAB> seed_genre <TAB> num_labels <TAB> label_1 <TAB> strength_1 <TAB> label_2 ...
data = []
with open(file_path, 'r') as file:
    for line_number, line in enumerate(file, start=1):
        # Skip comments and empty lines
        if line.startswith('#') or not line.strip():
            continue
        # Split each line into fields
        fields = line.strip().split('\t')

        # Report the offending line instead of failing later with a bare IndexError
        if len(fields) < 3:
            raise ValueError(
                f"{file_path}:{line_number}: expected at least 3 tab-separated fields, "
                f"got {len(fields)}"
            )

        # Parse the fixed leading columns
        track_id, seed_genre = fields[0], fields[1]
        num_labels = int(fields[2])
        labels = fields[3:]

        # The remaining fields must alternate label, strength, label, strength, ...
        if len(labels) % 2 != 0:
            raise ValueError(
                f"{file_path}:{line_number}: label {labels[-1]!r} has no matching strength value"
            )

        # Group labels and their strengths (a repeated label keeps its last strength)
        label_data = {
            label: float(strength)
            for label, strength in zip(labels[::2], labels[1::2])
        }

        # Append data as a dictionary
        data.append({
            'track_id': track_id,
            'seed_genre': seed_genre,
            'num_labels': num_labels,
            'labels': label_data
        })

# Convert the parsed data into a DataFrame
df = pd.DataFrame(data)
print("Initial DataFrame:")
print(len(df))
df.head()

: 

# Step 3: Count tracks per seed genre


In [18]:
# Tally how many tracks carry each seed genre; later cells use these tallies
# to keep only the genres with sufficient data.
print(len(df))
genre_counts = df['seed_genre'].value_counts()
num_genres = genre_counts.shape[0]
num_genres_under_50 = genre_counts.lt(50).sum()
num_genres_above_10000 = genre_counts.gt(10_000).sum()
print(
    f"Number of genres: {num_genres}",
    f"Number of genres with counts under 50: {num_genres_under_50}",
    f"Number of genres with counts above 10,000: {num_genres_above_10000}",
    sep="\n",
)
print("Genre Counts:\n", genre_counts)

677038
Number of genres: 6152
Number of genres with counts under 50: 6024
Number of genres with counts above 10,000: 14
Genre Counts:
 seed_genre
Rock                      261242
Pop                        57210
Electronic                 38235
Jazz                       37844
Hip-Hop                    31580
                           ...  
Psychedlic Folk-Rock           1
Psychedelic Blues Rock         1
Suave                          1
Alternative - magyar           1
All                            1
Name: count, Length: 6152, dtype: int64


# Step 4: Filter out categories with low counts or unusual genres

In [39]:
# A seed genre counts as "big" when its track count is strictly above this value.
# Author's notes on the resulting genre counts:
#   10,000 -> 14 genres, 11,000 -> 11 genres, 13,000 -> 9 genres
big_genre_threshold = 10_000

# Number of tracks per seed genre
genre_counts = df['seed_genre'].value_counts()

# Names of the big genres, then the rows of df that belong to one of them
big_genres = genre_counts.loc[genre_counts.gt(big_genre_threshold)].index
big_genres_df = df.loc[df['seed_genre'].isin(big_genres)]

print(f"Number of Big Genres: {big_genres.size}")
print(f"Filtered DataFrame rows for Big Genres: {big_genres_df.shape[0]}")
print("Sample of Filtered DataFrame:")
print(big_genres_df.head())

# Sanity-check the tallies against the thresholds
print(
    f"Total genres: {genre_counts.shape[0]}",
    f"Genres with counts under 50: {genre_counts.lt(50).sum()}",
    f"Genres with counts above {big_genre_threshold}: {genre_counts.gt(big_genre_threshold).sum()}",
    sep="\n",
)

Number of Big Genres: 14
Filtered DataFrame rows for Big Genres: 574269
Sample of Filtered DataFrame:
             track_id seed_genre  num_labels  \
0  TRAAAAK128F9318786       Rock         201   
1  TRAAAAV128F421A322       Rock           8   
2  TRAAAAW128F429D538    Hip-Hop         133   
3  TRAAAAY128F42A73F0      World           1   
4  TRAAABD128F429CF47       Rock          40   

                                              labels  
0  {'Rock': 0.6766169, 'Metal': 0.09950249, 'Hard...  
1                         {'Rock': 0.5, 'Punk': 0.5}  
2  {'Hip-Hop': 0.48872182, 'Hip-Hop/Rap': 0.27067...  
3                                     {'World': 1.0}  
4  {'Rock': 0.4, 'Rock/Pop': 0.15, 'Oldies': 0.1,...  
Total genres: 6152
Genres with counts under 50: 6024
Genres with counts above 10000: 14


# Step 5: Remove songs with fewer than 10 user labels

In [40]:
# Keep only songs whose num_labels is at least min_song_threshold
min_song_threshold = 10
big_genres_df = big_genres_df[big_genres_df['num_labels'] >= min_song_threshold]
# Interpolate the threshold so the message cannot drift from the filter actually applied
print(f"Filtered DataFrame with songs having {min_song_threshold} or more labels: {big_genres_df.shape[0]} songs")

# Remove songs with 'other' as the seed genre (if any)
# NOTE(review): this is an exact, case-sensitive match, while the seed genres shown in
# the sample output are capitalised ("Rock", "Jazz") — confirm 'other' is the intended spelling.
explicit_genre_exclude = ['other']
big_genres_df = big_genres_df[~big_genres_df['seed_genre'].isin(explicit_genre_exclude)]
print(f"Filtered DataFrame after removing 'other' genre: {big_genres_df.shape[0]} songs")

# Display a sample of the filtered DataFrame
print("Sample of the filtered DataFrame:")
print(big_genres_df.head())


Filtered DataFrame with songs having 5 or more labels: 339099 songs
Filtered DataFrame after removing 'other' genre: 339099 songs
Sample of the filtered DataFrame:
             track_id seed_genre  num_labels  \
0  TRAAAAK128F9318786       Rock         201   
2  TRAAAAW128F429D538    Hip-Hop         133   
4  TRAAABD128F429CF47       Rock          40   
8  TRAAAED128E0783FAB       Jazz        2227   
9  TRAAAEF128F4273421       Rock         181   

                                              labels  
0  {'Rock': 0.6766169, 'Metal': 0.09950249, 'Hard...  
2  {'Hip-Hop': 0.48872182, 'Hip-Hop/Rap': 0.27067...  
4  {'Rock': 0.4, 'Rock/Pop': 0.15, 'Oldies': 0.1,...  
8  {'Jazz': 0.6897171, 'Pop': 0.061517738, 'Gener...  
9  {'Rock': 0.45303866, 'New Wave': 0.13812155, '...  


# Step 6: Clean the labels for each song (remove low-probability and unused genres)

In [41]:
# Labels whose probability is below this cut-off are discarded
probability_threshold = 0.01

def clean_labels(label_data: dict, threshold: float, valid_genres: list[str]):
    """
    Filter a label dictionary by probability and by a list of allowed genres.

    Parameters:
        label_data (dict): Mapping of genre name to its probability.
        threshold (float): A genre is kept only if its probability is >= this value.
        valid_genres (list[str]): A genre is kept only if it appears in this list.

    Returns:
        dict: New dictionary holding the genres that pass both checks,
        in the same order as in label_data.
    """
    kept = {}
    for genre, prob in label_data.items():
        # Both conditions must hold: probable enough AND an allowed genre
        if prob >= threshold and genre in valid_genres:
            kept[genre] = prob
    return kept

# Apply the cleaning function to the 'labels' column.
# `assign` builds a new DataFrame instead of writing a column into a frame that came
# from boolean filtering, which is the pattern that can raise SettingWithCopyWarning.
big_genres_df = big_genres_df.assign(
    labels=big_genres_df['labels'].apply(
        clean_labels,
        threshold=probability_threshold,
        valid_genres=list(big_genres),
    )
)

# Remove rows where the cleaned 'labels' dictionary is empty
# (i.e., no genre passed both the probability threshold and the valid-genre check)
big_genres_df = big_genres_df[big_genres_df['labels'].apply(bool)]

print(f"DataFrame after cleaning labels: {big_genres_df.shape[0]} songs")
print("Sample of cleaned labels:")
print(big_genres_df[['track_id', 'labels']].head())


DataFrame after cleaning labels: 327066 songs
Sample of cleaned labels:
             track_id                                   labels
0  TRAAAAK128F9318786                      {'Rock': 0.6766169}
2  TRAAAAW128F429D538                  {'Hip-Hop': 0.48872182}
4  TRAAABD128F429CF47              {'Rock': 0.4, 'Pop': 0.075}
8  TRAAAED128E0783FAB  {'Jazz': 0.6897171, 'Pop': 0.061517738}
9  TRAAAEF128F4273421   {'Rock': 0.45303866, 'Pop': 0.0718232}


In [2]:
import os
import json
import shutil
import pretty_midi

# Paths
# NOTE(review): none of these names are referenced by the other cells visible in this
# notebook — confirm whether this cell is still needed here.
aligned_folder = '1lakh_song_dataset_original'  # presumably the input dataset folder — TODO confirm
match_scores_path = 'match_scores.json'  # presumably a JSON file of match scores — TODO confirm
destination_folder = '1lakh_song_dataset_cleaned'  # presumably the output folder for cleaned data — TODO confirm

ModuleNotFoundError: No module named 'pretty_midi'

# Step 7: Save the cleaned data to JSON and .cls file

In [42]:
# Save the final DataFrame in two formats: a JSON list of records and a
# tab-separated .cls file.
json_output_path = 'cleaned_data.json'
with open(json_output_path, 'w') as json_file:
    json.dump(big_genres_df.to_dict(orient='records'), json_file, indent=2)
print(f"Data saved to {json_output_path}")

# CLS format (tab-separated as per original format), one line per song:
#   track_id <TAB> seed_genre <TAB> num_labels <TAB> label <TAB> strength ...
cls_output_path = 'cleaned_data.cls'
with open(cls_output_path, 'w') as cls_file:
    # Walk the four needed columns in lockstep; `iterrows` would construct a
    # Series object for every row, which is slow on a frame this large.
    for track_id, seed_genre, num_labels, labels in zip(
        big_genres_df['track_id'],
        big_genres_df['seed_genre'],
        big_genres_df['num_labels'],
        big_genres_df['labels'],
    ):
        fields = [str(track_id), str(seed_genre), str(num_labels)]
        for label, strength in labels.items():
            fields.extend((str(label), str(strength)))
        # Join once per line rather than growing a string field by field
        cls_file.write('\t'.join(fields) + '\n')
print(f"Data saved to {cls_output_path}")

Data saved to cleaned_data.json
Data saved to cleaned_data.cls
