<a href="https://colab.research.google.com/github/joris-vaneyghen/mss-jazz-playalong/blob/main/segmentation/cluster_segments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the project repository, which contains the audio example;
# later cells read from paths under mss-jazz-playalong/
!git clone https://github.com/joris-vaneyghen/mss-jazz-playalong.git

In [None]:
import numpy as np
import json
import os

from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt


In [None]:
# Directory scanned for the .mp3 files of the tracks to include
input_path = 'mss-jazz-playalong/examples'
# Directory holding the per-track .json files that are loaded for each .mp3 name
output_path = 'mss-jazz-playalong/out/segment_and_tag'
# Resolution of the EfficientAT model; presumably seconds per frame — TODO confirm.
# Not referenced by any other cell visible in this notebook.
resolution = 0.32 # resolution of EfficientAT model

In [None]:
def load_json(dir, mp3_file):
    """Load the JSON document that accompanies an mp3 file.

    The JSON file name is obtained by replacing every occurrence of
    '.mp3' in ``mp3_file`` with '.json'; it is looked up inside ``dir``.

    Args:
        dir: Directory that should contain the JSON file.
        mp3_file: File name of the mp3 track.

    Returns:
        The parsed JSON content, or an empty dict when no such file exists.
    """
    json_path = os.path.join(dir, mp3_file.replace('.mp3', '.json'))

    if os.path.exists(json_path):
        with open(json_path, 'r') as handle:
            return json.load(handle)

    # No JSON next to this track: callers get an empty mapping instead of an error
    return {}

# def save_json(dir, mp3_file, data):
#     # Replace .mp3 extension with .json
#     json_file_name = mp3_file.replace('.mp3', '.json')
#     file_path = os.path.join(dir, json_file_name)

#     # Check if directory exists, create it if not
#     if not os.path.exists(dir):
#         os.makedirs(dir)

#     # Save the data to the .json file
#     with open(file_path, 'w') as file:
#         json.dump(data, file, indent=4)

def iterate_files(dir):
    """Yield the names of the ``.mp3`` entries in ``dir``, in sorted order.

    ``os.listdir`` returns entries in arbitrary order, which would make the
    order of the collected samples (and anything fitted on them) depend on
    the file system. Sorting keeps repeated runs consistent.

    Args:
        dir: Directory to scan (not recursive).

    Yields:
        File names (without the directory part) that end with '.mp3'.
    """
    for file_name in sorted(os.listdir(dir)):
        if file_name.endswith('.mp3'):
            yield file_name

In [None]:
# Per-segment values gathered across all tracks; every list gets one entry per segment
segment_lengths = []
segment_preds = []
drums, bass, other, vocals = [], [], [], []

for mp3_file in iterate_files(input_path):
    data = load_json(output_path, mp3_file)

    # Skip tracks whose JSON lacks either the 'demucs' or the 'segments' key
    # (this includes tracks with no JSON at all, for which load_json gives {})
    if not {'demucs', 'segments'} <= data.keys():
        continue

    for seg in data['segments']:
        # Length is expressed in the same index units as start_idx / end_idx
        segment_lengths.append(seg['end_idx'] - seg['start_idx'])
        segment_preds.append(seg['preds'])
        drums.append(seg['drums'])
        bass.append(seg['bass'])
        other.append(seg['other'])
        vocals.append(seg['vocals'])



In [None]:
# Histogram (20 bins) of how the collected segment lengths are distributed
fig, ax = plt.subplots()
ax.hist(segment_lengths, bins=20)
ax.set_xlabel("Segment Length")
ax.set_ylabel("Frequency")
ax.set_title("Frequency Chart of Segment Lengths")
plt.show()


In [None]:
# Turn the collected Python lists into NumPy arrays
segment_preds = np.array(segment_preds)
drums, bass, other, vocals = (
    np.array(stem) for stem in (drums, bass, other, vocals)
)

# Place the four stem arrays side by side along a new axis 1
demucs_features = np.stack([drums, bass, other, vocals], axis=1)

In [None]:
# Candidate numbers of clusters to evaluate: k = 8 .. 49
cluster_range = range(8, 50)

# Per-k metrics: SSE (KMeans inertia) for the elbow plot and the mean
# silhouette score, which is what the selection of k below is based on
sse = []
silhouette_scores = []

# Project the combined features onto 20 principal components.
# random_state is set because scikit-learn may pick a randomized SVD solver,
# in which case the projection would otherwise differ from run to run.
pca = PCA(n_components=20, random_state=42)

# The demucs-derived columns are multiplied by this factor before PCA;
# presumably so they are not drowned out by the prediction columns — TODO confirm
demucs_weight = 10
combined = np.concatenate((segment_preds, demucs_features * demucs_weight), axis=1)
data = pca.fit_transform(combined)

# Perform KMeans clustering for different values of k.
# n_init is given explicitly: its default depends on the scikit-learn version
# (10 in older releases, 'auto' from 1.4), which would change the results.
for k in cluster_range:
    kmeans = KMeans(n_clusters=k, n_init=10, random_state=42)
    kmeans.fit(data)
    sse.append(kmeans.inertia_)  # SSE for elbow method
    silhouette_avg = silhouette_score(data, kmeans.labels_)
    silhouette_scores.append(silhouette_avg)

# Plot SSE for elbow method
plt.figure(figsize=(10, 5))
plt.plot(cluster_range, sse, 'bx-')
plt.xlabel('Number of clusters (k)')
plt.ylabel('SSE (Sum of Squared Distances)')
plt.title('Elbow Method for Optimal k')
plt.show()

# Plot Silhouette Score for each k
plt.figure(figsize=(10, 5))
plt.plot(cluster_range, silhouette_scores, 'bx-')
plt.xlabel('Number of clusters (k)')
plt.ylabel('Silhouette Score')
plt.title('Silhouette Score for Different k')
plt.show()

# Select the k with the highest mean silhouette score (np.argmax returns the
# first one on ties); the plots above are only a visual sanity check
best_k = cluster_range[np.argmax(silhouette_scores)]
print(f"Best number of clusters: {best_k}")

# Refit KMeans with the selected k; this is the model the rest of the notebook uses
kmeans = KMeans(n_clusters=best_k, n_init=10, random_state=42)
labels = kmeans.fit_predict(data)

# Print cluster labels for each sample
print("Cluster labels for the data points:", labels)

# Number of samples assigned to each cluster
unique_labels, label_counts = np.unique(labels, return_counts=True)

# Plot the frequencies of labels
plt.bar(unique_labels, label_counts)
plt.xlabel('Cluster Labels')
plt.ylabel('Frequency')
plt.title('Frequency of Cluster Labels')
plt.show()


In [None]:
# plot frequencies of labels

import matplotlib.pyplot as plt
import numpy as np



In [None]:
import pickle

# Save the PCA model to a file
def save_pca_model(pca, filename="pca_model.pkl"):
    """Pickle ``pca`` into the file ``filename`` (overwriting it if present).

    Args:
        pca: The object to serialize; intended for a fitted PCA model.
        filename: Destination path, 'pca_model.pkl' by default.
    """
    with open(filename, "wb") as out_file:
        pickle.Pickler(out_file).dump(pca)

# Load the PCA model from a file
def load_pca_model(filename="pca_model.pkl"):
    """Unpickle and return the object stored in ``filename``.

    Note: unpickling can execute arbitrary code, so only load files that
    come from a trusted source.

    Args:
        filename: Path of the pickle file, 'pca_model.pkl' by default.

    Returns:
        The object that was stored in the file.
    """
    with open(filename, "rb") as in_file:
        restored = pickle.Unpickler(in_file).load()
    return restored

# Project new data using the loaded PCA model
def transform_data(pca_model, new_data):
    """Pass ``new_data`` through ``pca_model.transform`` and return the result.

    Args:
        pca_model: Object exposing a ``transform`` method, e.g. a loaded PCA.
        new_data: Input handed to ``transform`` unchanged.

    Returns:
        Whatever ``pca_model.transform(new_data)`` returns.
    """
    projected = pca_model.transform(new_data)
    return projected

def save_kmeans_model(kmeans, filename="kmeans_model.pkl"):
    """Pickle ``kmeans`` into the file ``filename`` (overwriting it if present).

    Args:
        kmeans: The object to serialize; intended for a fitted KMeans model.
        filename: Destination path, 'kmeans_model.pkl' by default.
    """
    with open(filename, "wb") as out_file:
        pickle.Pickler(out_file).dump(kmeans)

# Load the KMeans model from a file
def load_kmeans_model(filename="kmeans_model.pkl"):
    """Unpickle and return the object stored in ``filename``.

    Note: unpickling can execute arbitrary code, so only load files that
    come from a trusted source.

    Args:
        filename: Path of the pickle file, 'kmeans_model.pkl' by default.

    Returns:
        The object that was stored in the file.
    """
    with open(filename, "rb") as in_file:
        restored = pickle.Unpickler(in_file).load()
    return restored

# Predict cluster labels for new data using the loaded KMeans model
def predict_cluster_labels(kmeans_model, new_data):
    """Pass ``new_data`` through ``kmeans_model.predict`` and return the result.

    Args:
        kmeans_model: Object exposing a ``predict`` method, e.g. a loaded KMeans.
        new_data: Input handed to ``predict`` unchanged.

    Returns:
        Whatever ``kmeans_model.predict(new_data)`` returns.
    """
    assigned = kmeans_model.predict(new_data)
    return assigned

# Example usage: write both fitted models to disk, load them back, and run
# them on new data.

# Persist the fitted models ('pca' and 'kmeans' must already be trained)
save_pca_model(pca)
save_kmeans_model(kmeans)

# Later, to load and use the models.
# NOTE: unpickling can execute arbitrary code; only load files you created yourself.
loaded_pca = load_pca_model()
loaded_kmeans = load_kmeans_model()

# Example new data -- replace with your actual new data.
# The number of columns is read from the fitted PCA instead of being
# hard-coded, so the example keeps working if the feature layout changes.
# The generator is seeded so the placeholder input is the same on every run.
n_features = loaded_pca.n_features_in_
new_data = np.random.default_rng(42).random((10, n_features))

# Project the new data using the loaded PCA model
transformed_data = transform_data(loaded_pca, new_data)
print(transformed_data.shape)

# Predict the cluster labels for the new data
new_labels = predict_cluster_labels(loaded_kmeans, transformed_data)
print(new_labels)