<a href="https://colab.research.google.com/github/joris-vaneyghen/mss-jazz-playalong/blob/main/segmentation/demucs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Demucs and clone the repository that provides the example audio
# (input_path in the config cell points into this checkout).
!pip install demucs -q
!git clone https://github.com/joris-vaneyghen/mss-jazz-playalong.git -q

In [None]:
# CONFIG
# Directory scanned for .mp3 files to separate.
input_path = 'mss-jazz-playalong/examples'
# Directory holding one .json results file per processed .mp3 (read and written).
output_path = 'output'
# Name of the pretrained Demucs model passed to pretrained.get_model.
model_name = 'htdemucs_ft'
# Window length in seconds over which each separated source's amplitude is
# averaged; chosen to match the time resolution of the EfficientAT model.
resolution = 0.32

In [None]:
import torch
from demucs.apply import apply_model
from demucs import pretrained
import json
import os
import torchaudio
import numpy as np

In [None]:
def load_json(dir, mp3_file):
    """Load the JSON results file that accompanies an audio file.

    The JSON file lives in ``dir`` and has the same base name as ``mp3_file``
    with its extension replaced by ``.json``.

    Args:
        dir: Directory containing the JSON results files.
        mp3_file: File name (not a path) of the audio file, e.g. ``'song.mp3'``.

    Returns:
        The decoded JSON content, or an empty dict when no results file
        exists yet.
    """
    # Replace only the trailing extension; str.replace('.mp3', ...) would also
    # rewrite any '.mp3' occurring earlier in the name, and would leave a name
    # without '.mp3' untouched (pointing at the audio file itself).
    json_file_name = os.path.splitext(mp3_file)[0] + '.json'
    file_path = os.path.join(dir, json_file_name)

    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            return json.load(file)
    except FileNotFoundError:
        # No results stored for this file yet.
        return {}

def save_json(dir, mp3_file, data):
    """Write ``data`` as indented JSON to the results file for ``mp3_file``.

    The target is ``<dir>/<mp3 base name>.json``; ``dir`` is created if it
    does not exist. The data is first written to a temporary file and then
    moved into place with ``os.replace``, so an interrupted run never leaves a
    truncated JSON file that a later ``json.load`` would fail to parse.

    Args:
        dir: Output directory for the JSON results files.
        mp3_file: File name (not a path) of the audio file, e.g. ``'song.mp3'``.
        data: JSON-serializable object to store.
    """
    # Replace only the trailing extension (see os.path.splitext).
    json_file_name = os.path.splitext(mp3_file)[0] + '.json'
    file_path = os.path.join(dir, json_file_name)

    # exist_ok avoids the check-then-create race; an empty dir means the
    # current directory, which os.makedirs would reject.
    if dir:
        os.makedirs(dir, exist_ok=True)

    tmp_path = file_path + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as file:
            json.dump(data, file, indent=4)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def iterate_files(dir):
    """Yield the names of the ``.mp3`` files directly inside ``dir``.

    Names are yielded in sorted order so that repeated runs process files in
    the same, reproducible order (``os.listdir`` order is arbitrary). Entries
    that are not regular files, such as a directory named ``x.mp3``, are
    skipped.

    Args:
        dir: Directory to scan (not recursive).

    Yields:
        File names (not paths) ending in ``.mp3``.
    """
    for file_name in sorted(os.listdir(dir)):
        if file_name.endswith('.mp3') and os.path.isfile(os.path.join(dir, file_name)):
            yield file_name

def load_model(name=None):
    """Load a pretrained Demucs model and move it to the best available device.

    Args:
        name: Pretrained model name passed to ``pretrained.get_model``.
            Defaults to the notebook-level ``model_name`` config value, read
            at call time.

    Returns:
        ``(model, device)``: the model, switched to evaluation mode, and the
        ``torch.device`` it was moved to (CUDA when available, else CPU).
    """
    model = pretrained.get_model(model_name if name is None else name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # The model is only used for inference; put it in evaluation mode in case
    # any of its layers behave differently while training.
    model.eval()
    return model, device

def calculate_demucs(model, device, dir, mp3_file):
    """Compute per-source amplitude envelopes for one audio file with Demucs.

    The file is loaded, resampled to the model's sample rate if needed and
    separated into the model's sources. Each source's channels are averaged,
    then it is reduced to the mean absolute amplitude over consecutive,
    non-overlapping windows of ``resolution`` seconds (the notebook config
    value). A trailing partial window is dropped.

    Args:
        model: Demucs model as returned by ``load_model``.
        device: Device to run the separation on.
        dir: Directory containing the audio file.
        mp3_file: Audio file name inside ``dir``.

    Returns:
        A numpy array of shape ``(n_sources, n_windows)``. Rows presumably
        follow ``model.sources`` order (as in demucs' own separate script);
        confirm before relying on it. ``n_windows`` is 0 for audio shorter
        than one window.
    """
    waveform, sample_rate = torchaudio.load(os.path.join(dir, mp3_file))
    if sample_rate != model.samplerate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, model.samplerate)

    # The model expects a fixed channel count (stereo for the htdemucs
    # models). Duplicate a mono track and drop any extra channels.
    n_channels = waveform.shape[0]
    expected_channels = getattr(model, 'audio_channels', n_channels)
    if n_channels == 1 and expected_channels > 1:
        waveform = waveform.repeat(expected_channels, 1)
    elif n_channels > expected_channels:
        waveform = waveform[:expected_channels]

    waveform = waveform.to(device)
    # Gradients are never needed here; disabling them saves memory on long tracks.
    with torch.no_grad():
        out = apply_model(model, waveform.unsqueeze(0), progress=True, device=device)

    # Remove the batch dim and average over the Left/Right channels:
    # (1, sources, channels, time) -> (sources, time).
    out = out.squeeze(0).mean(dim=1)

    # Round rather than truncate so float error in samplerate * resolution
    # can't lose a sample; never allow an empty window.
    chunk_size = max(1, int(round(model.samplerate * resolution)))
    if out.shape[1] < chunk_size:
        # Shorter than one window: unfold would raise, so return no windows.
        return out.new_zeros((out.shape[0], 0)).cpu().numpy()

    # Split time into non-overlapping windows: (sources, n_windows, chunk_size).
    windows = out.unfold(dimension=1, size=chunk_size, step=chunk_size)
    envelope = windows.abs().mean(dim=2)

    return envelope.cpu().numpy()


In [None]:
# Separate every .mp3 in input_path that does not yet have Demucs results and
# store the per-source envelopes in the matching JSON file in output_path.
model, device = load_model()

for mp3_file in iterate_files(input_path):
    data = load_json(output_path, mp3_file)
    if 'demucs' in data:
        continue  # already processed in an earlier run
    print(mp3_file)
    data['demucs'] = calculate_demucs(model, device, input_path, mp3_file).tolist()
    # Record what the matrix rows and columns mean so the stored result is
    # self-describing: source names per row, window length (s) per column.
    data['demucs_sources'] = list(model.sources)
    data['demucs_resolution'] = resolution
    save_json(output_path, mp3_file, data)