<a href="https://colab.research.google.com/github/joris-vaneyghen/mss-jazz-playalong/blob/main/segmentation/demucs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/facebookresearch/demucs -q
!git clone https://github.com/joris-vaneyghen/mss-jazz-playalong.git -q

In [None]:
# CONFIG
# Name of the pretrained model handed to demucs.api.Separator.
model_name = "htdemucs_ft"
# Directory that is scanned for .mp3 files.
input_path = "mss-jazz-playalong/examples"
# Directory that receives one .json result file per processed mp3.
output_path = "output"
# Length of one analysis chunk in seconds (same as the resolution of the EfficientAT model).
resolution = 0.32

In [None]:
import torch
import demucs.api
import json
import os
import torchaudio
import numpy as np

In [None]:
def load_json(dir, mp3_file):
    """Return the JSON content stored for ``mp3_file`` inside ``dir``.

    The name of the JSON file is obtained by replacing ``.mp3`` with ``.json``
    in ``mp3_file``. If that file does not exist, an empty dict is returned.
    """
    json_path = os.path.join(dir, mp3_file.replace('.mp3', '.json'))

    if os.path.exists(json_path):
        with open(json_path, 'r') as handle:
            return json.load(handle)

    # Nothing has been stored for this track yet
    return {}

def save_json(dir, mp3_file, data):
    """Write ``data`` as indented JSON to the result file belonging to ``mp3_file``.

    Args:
        dir: Target directory. It is created (including parents) when missing;
            an empty string means the current working directory.
        mp3_file: Name of the mp3 file. ``.mp3`` is replaced with ``.json`` to
            obtain the name of the file that is written.
        data: JSON-serialisable object to store.

    Raises:
        TypeError, ValueError: Propagated from ``json.dump`` when ``data`` cannot
            be serialised. A previously saved file is left untouched in that case.
        OSError: If the directory or file cannot be created or written.
    """
    # Replace .mp3 extension with .json
    json_file_name = mp3_file.replace('.mp3', '.json')
    file_path = os.path.join(dir, json_file_name)

    # exist_ok avoids the check-then-create race; os.makedirs('') would raise,
    # so an empty dir (current directory) is skipped.
    if dir:
        os.makedirs(dir, exist_ok=True)

    # Write to a temporary sibling first and swap it in afterwards, so that a
    # failed or interrupted dump never leaves a truncated .json file behind.
    tmp_path = file_path + '.tmp'
    try:
        with open(tmp_path, 'w') as file:
            json.dump(data, file, indent=4)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Clean up the partial temp file, then let the caller see the error
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def iterate_files(dir):
    """Lazily yield the name of every entry in ``dir`` that ends with ``.mp3``."""
    yield from (entry for entry in os.listdir(dir) if entry.endswith('.mp3'))

def load_separator():
    """Create the Demucs separator used by this notebook.

    The model is selected through the module-level ``model_name``. The separator
    is configured to run on CUDA when available (CPU otherwise) and to show
    progress.
    """
    demucs_separator = demucs.api.Separator(model=model_name)

    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")

    demucs_separator.update_parameter(device=target_device, progress=True)
    return demucs_separator

def calculate_demucs(separator, dir, mp3_file, max_chunks=1000):
    """Separate one mp3 into stems and reduce each stem to a coarse amplitude curve.

    The track is loaded, resampled to the separator's sample rate when needed
    and passed to ``separator.separate_tensor`` in segments. For every returned
    stem the channels are averaged and the mean absolute amplitude of each chunk
    of ``resolution`` seconds (module-level setting) is recorded.

    Args:
        separator: Object exposing ``samplerate`` and ``separate_tensor``, such as
            the one returned by ``load_separator``.
        dir: Directory that contains the mp3 file.
        mp3_file: File name of the track inside ``dir``.
        max_chunks: Maximum number of chunks separated per ``separate_tensor``
            call, which bounds the size of the tensors held at once. ``None``
            separates the whole track in a single call.

    Returns:
        dict mapping stem name to a list of floats, one value per complete chunk.
        Trailing samples that do not fill a whole chunk are ignored.

    Raises:
        ValueError: If ``resolution`` yields an empty chunk or ``max_chunks`` is
            smaller than 1.
    """
    # NOTE(review): the original body read an undefined global `max_chunks`; the
    # default of 1000 is a reviewer choice - tune it to the available memory.

    # Load waveform and handle resampling
    waveform, sample_rate = torchaudio.load(os.path.join(dir, mp3_file))
    if sample_rate != separator.samplerate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, separator.samplerate)

    # Determine the number of samples per chunk and per segment
    chunk_size = int(separator.samplerate * resolution)
    if chunk_size < 1:
        raise ValueError("resolution is too small: a chunk would contain no samples")

    num_samples = waveform.shape[1]
    if max_chunks is None:
        # One segment spanning the whole track (never smaller than one chunk,
        # so the range step below is always positive)
        max_samples_per_segment = max(num_samples, chunk_size)
    else:
        if max_chunks < 1:
            raise ValueError("max_chunks must be at least 1 (or None)")
        max_samples_per_segment = chunk_size * max_chunks

    # Per-stem results; further stems reported by the model are added on demand
    output_segments = {'drums': [], 'bass': [], 'vocals': [], 'other': []}

    # Process each segment of the waveform separately
    for start in range(0, num_samples, max_samples_per_segment):
        end = min(start + max_samples_per_segment, num_samples)
        if end - start < chunk_size:
            # Only the last segment can be this short. It holds no complete
            # chunk, and unfold() raises when size exceeds the tensor length.
            break
        segment = waveform[:, start:end]

        _, separated = separator.separate_tensor(segment)

        for key, out in separated.items():
            # Average over left/right channels
            out = out.mean(dim=0)
            # Reshape the tensor into chunks (an incomplete last chunk is dropped)
            out_reshaped = out.unfold(dimension=0, size=chunk_size, step=chunk_size)
            # Mean absolute amplitude along the time dimension of each chunk
            out_reduced = out_reshaped.abs().mean(dim=1)
            # tolist() works regardless of the device the tensor lives on
            output_segments.setdefault(key, []).extend(out_reduced.tolist())

    return output_segments

In [None]:
# Build the separator once and reuse it for every track
separator = load_separator()

for mp3_file in iterate_files(input_path):
    data = load_json(output_path, mp3_file)
    if 'demucs' in data.keys():
        # Result already stored for this track
        continue

    print(mp3_file)
    data['demucs'] = calculate_demucs(separator, input_path, mp3_file)
    save_json(output_path, mp3_file, data)