In [None]:
import pandas as pd

from select_caption_data import (
    load_musiccaps_balanced_subset,
    LOW_QUALITY_LABELS,
    remove_by_matching_labels,
    load_musiccaps_genre_annotations,
    load_song_describer_data_for_generation,
    load_mtg_jamendo_genre_annotations,
)

In [None]:
# Load the balanced MusicCaps subset and immediately pass it through
# remove_by_matching_labels with LOW_QUALITY_LABELS on "aspect_list"
# (presumably dropping rows tagged as low quality -- confirm against the helper).
mc_df = remove_by_matching_labels(
    load_musiccaps_balanced_subset("data/musiccaps/musiccaps-public.csv"),
    LOW_QUALITY_LABELS,
    "aspect_list",
)
# Row count after the label-based filtering.
len(mc_df)

In [None]:
# remove any ids from videos that errored out

# Read the error log line by line (assumed to hold one YouTube id per line --
# confirm against whatever writes it). Each line is stripped so that stray
# leading/trailing whitespace cannot produce an id that silently fails to match
# `ytid`, and blank or whitespace-only lines are skipped instead of being
# collected and removed afterwards.
with open("data/musiccaps/error.log", "r") as f:
    missing_ids = {line.strip() for line in f if line.strip()}

# Keep only the rows whose ytid is not listed in the error log.
mc_df = mc_df[~mc_df["ytid"].isin(missing_ids)]
len(mc_df)

In [None]:
# Load the MusicCaps genre annotations and name the result "genres" so the
# merged column carries that name (the loader presumably returns a Series
# keyed by "name" -- confirm).
mc_genres = load_musiccaps_genre_annotations(
    "data/musiccaps/musiccaps_preds.csv"
).rename("genres")

# Left merge: every row of mc_df is kept; rows without a matching "name"
# end up with a missing value in "genres".
mc_df = mc_df.merge(
    mc_genres,
    left_on=["ytid"],
    right_on=["name"],
    how="left",
)

mc_df.describe()

In [None]:
# Display the current MusicCaps frame for a visual sanity check.
mc_df

In [None]:
# Genre values that should not be part of the MusicCaps selection.
mc_excluded_genres = ["Stage & Screen", "Non-Music", "Children's", "Brass & Military"]

# Boolean mask that is True for rows whose "genres" value is NOT excluded
# (rows with a missing genre are not in the list, so they stay).
mc_keep_mask = ~mc_df["genres"].isin(mc_excluded_genres)
mc_df = mc_df[mc_keep_mask]
len(mc_df)

In [None]:
# Number of remaining MusicCaps rows per genre.
mc_df.loc[:, "genres"].value_counts()

In [None]:
# 648 - (78 + 17 + 11 + 3 * 15)

In [None]:
# Cap every genre at 85 samples: keep only the first 85 rows of each
# "genres" group, in the frame's existing row order.
# NOTE(review): rows with a missing genre may be dropped by the groupby here -- confirm.
mc_df = mc_df.groupby(by="genres").head(n=85)

In [None]:
# Bar chart of per-genre row counts for MusicCaps; N in the title is the
# total number of rows currently in mc_df.
mc_df["genres"].value_counts().plot.bar(
    title=f"MusicCaps Genre Distribution (N={len(mc_df)})",
    xlabel="Genre",
    ylabel="Number of Samples",
)

In [None]:
# save to csv
# mc_df.to_csv("data/musiccaps/musiccaps-for-generation.csv", index=False)

In [None]:
# Input files for the Song Describer loader.
sdd_captions_path = "data/SongDescriberDataset/song_describer.csv"
sdd_annotations_path = (
    "data/SongDescriberDataset/music-classification-annotations-clean.tsv"
)

sdd_df = load_song_describer_data_for_generation(
    sdd_captions_path,
    sdd_annotations_path,
)

# Number of distinct tracks, then the frame itself for inspection.
print(sdd_df["track_id"].nunique())
sdd_df

In [None]:
# Load the MTG-Jamendo genre annotations and display them for inspection.
mtg_predictions_path = "data/mtg-jamendo-predictions.tsv"
mtg_genres = load_mtg_jamendo_genre_annotations(mtg_predictions_path)
mtg_genres

In [None]:
# Attach the MTG-Jamendo genre annotations: left merge of sdd_df's "track_id"
# against the index of mtg_genres, so every Song Describer row is kept.
sdd_df = sdd_df.merge(
    mtg_genres,
    left_on="track_id",
    right_index=True,
    how="left",
)

# Drop rows whose "genre" is one of the excluded categories
# (the "genre" column is presumably supplied by mtg_genres -- confirm).
sdd_excluded_genres = ["Stage & Screen", "Non-Music", "Children's", "Brass & Military"]
sdd_keep_mask = ~sdd_df["genre"].isin(sdd_excluded_genres)
sdd_df = sdd_df[sdd_keep_mask]

# Distinct tracks remaining after the filter, then the frame for inspection.
print(sdd_df["track_id"].nunique())
sdd_df

In [None]:
# sdd_df.to_csv("data/SongDescriberDataset/song_describer-for-generation.csv", index=False)

In [None]:
# Reduce to one genre per track (the first "genre" value within each track_id
# group) so that a track spanning several rows is counted once, then count
# tracks per genre.
sdd_genre_per_track = sdd_df.groupby("track_id")["genre"].first()
sdd_genre_counts = sdd_genre_per_track.value_counts()
# sdd_genre_counts.index = sdd_genre_counts.index.str[0]

# N in the title is the number of track_id groups.
# NOTE(review): if some tracks have no genre, N can exceed the sum of the bars -- confirm.
sdd_track_total = len(sdd_df.groupby("track_id"))
sdd_genre_counts.plot.bar(
    title=f"SongDescriber Genre Distribution (N={sdd_track_total})",
    xlabel="Genre",
    ylabel="Number of Samples",
)

In [None]:
# Per-genre row counts for MusicCaps, kept for the cross-dataset comparison.
mc_genre_counts = mc_df.loc[:, "genres"].value_counts()

In [None]:
# Build a genre-by-dataset table of counts:
#   1. place the two count series side by side, one column per dataset;
#   2. a genre absent from one dataset gets 0 instead of a missing value;
#   3. add a "Total" column with the row-wise sum;
#   4. order genres by descending total;
#   5. order the columns by descending name.
summed_genre_counts = (
    pd.concat(
        [sdd_genre_counts, mc_genre_counts],
        axis=1,
        keys=["SongDescriber", "MusicCaps"],
    )
    .fillna(0)
    .assign(Total=lambda counts: counts.sum(axis=1))
    .sort_values("Total", ascending=False)
    .sort_index(axis=1, ascending=False)
)

In [None]:
# Grouped bar chart: one group of bars per genre, one bar per column of
# summed_genre_counts.
summed_genre_counts.plot.bar(
    title="Genre Distribution Comparison",
    xlabel="Genre",
    ylabel="Number of Samples",
)