<a href="https://colab.research.google.com/github/mirchandani-mohnish/Musformer/blob/main/Demucs_Rebuild.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub
# Download the MUSDB18 dataset through kagglehub and keep the returned path.
# NOTE(review): this variable is not used anywhere below. The db_path cell
# and hparams.yaml hard-code '/kaggle/input/musdb18-music-source-separation-dataset',
# which presumably only exists on Kaggle. When running elsewhere (e.g. Colab),
# point db_path / hparams db_path at this variable instead. TODO confirm.
jakerr5280_musdb18_music_source_separation_dataset_path = kagglehub.dataset_download('jakerr5280/musdb18-music-source-separation-dataset')

print('Data source import complete.')


In [None]:
# Kaggle's Python 3 image (kaggle/python, https://github.com/kaggle/docker-python)
# preinstalls common analytics packages; two of them are imported here.

import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)

import os


def _walk_files(top):
    """Yield the full path of every file found under *top*, in os.walk order."""
    for folder, _subfolders, names in os.walk(top):
        for name in names:
            yield os.path.join(folder, name)


# Print every file under the read-only input directory so the dataset
# layout can be inspected before it is loaded.
for _path in _walk_files('/kaggle/input'):
    print(_path)

# Up to 20GB may be written to /kaggle/working/ (kept as output when a version
# is created with "Save & Run All"); /kaggle/temp/ is scratch space whose
# contents are not saved outside the current session.

/kaggle/input/musdb18-music-source-separation-dataset/The Long Wait - Dark Horses.stem.mp4
/kaggle/input/musdb18-music-source-separation-dataset/test/Raft Monk - Tiring.stem.mp4
/kaggle/input/musdb18-music-source-separation-dataset/test/BKS - Too Much.stem.mp4
/kaggle/input/musdb18-music-source-separation-dataset/test/Georgia Wonder - Siren.stem.mp4
/kaggle/input/musdb18-music-source-separation-dataset/test/The Sunshine Garcia Band - For I Am The Moon.stem.mp4
/kaggle/input/musdb18-music-source-separation-dataset/test/Enda Reilly - Cur An Long Ag Seol.stem.mp4
/kaggle/input/musdb18-music-source-separation-dataset/test/Buitraker - Revo X.stem.mp4
/kaggle/input/musdb18-music-source-separation-dataset/test/We Fell From The Sky - Not You.stem.mp4
/kaggle/input/musdb18-music-source-separation-dataset/test/The Mountaineering Club - Mallory.stem.mp4
/kaggle/input/musdb18-music-source-separation-dataset/test/Skelpolu - Resurrection.stem.mp4
/kaggle/input/musdb18-music-source-separation-dataset

In [None]:
# Install the MUSDB18 loader (musdb) and the evaluation toolkit (mir_eval,
# whose bss_eval_sources is imported by save_results in train.py).
!pip install musdb
!pip install mir_eval


%%capture
# Installing SpeechBrain via pip, from the branch named in BRANCH.
# NOTE(review): IPython only honours %%capture as the first line of a cell;
# this export appears to have merged two cells here. Confirm it starts its own cell.
BRANCH = 'develop'
!python -m pip install git+https://github.com/speechbrain/speechbrain.git@$BRANCH

# Clone SpeechBrain repository
# NOTE(review): nothing visible in this notebook uses the cloned sources.
# The pip install above already provides the package; confirm before removing.
!git clone https://github.com/speechbrain/speechbrain/

In [None]:
# Location of the MUSDB18 stems (read-only Kaggle input) and of the
# writable working directory. hparams.yaml keeps its own copy of db_path.
db_path, output_path = (
    '/kaggle/input/musdb18-music-source-separation-dataset',
    '/kaggle/working',
)

In [None]:
import os
import numpy as np

# NOTE(review): presumably restores the `np.float_` alias (removed in NumPy 2.0)
# that musdb's audio backend still references. It must run before `import musdb`.
np.float_ = np.float64
import musdb

MUS_DB_PATH = db_path

# Full database, the train/valid split of the "train" subset, and the test subset.
mus = musdb.DB(root=MUS_DB_PATH)
mus_train, mus_valid = (
    musdb.DB(root=MUS_DB_PATH, subsets="train", split=part)
    for part in ("train", "valid")
)
mus_test = musdb.DB(root=MUS_DB_PATH, subsets="test")

# Sanity check: show the first training track, then the first test track.
for _db in (mus_train, mus_test):
    print(_db[0])

In [None]:
%%file models.py

import torch
from torch import nn
from speechbrain.nnet.CNN import Conv1d, ConvTranspose1d
# from speechbrain.nnet.activations import GLU
from speechbrain.lobes.models.beats import GLU_Linear
from torch.nn import GLU
from speechbrain.nnet.RNN import LSTM
from speechbrain.nnet.linear import Linear

class EncoderBlock(nn.Module):
    """One Demucs encoder stage: strided conv + ReLU, then 1x1 conv + GLU.

    Both convolutions use ``skip_transpose=True``, i.e. they consume and
    produce channels-first ``[batch, channels, time]`` tensors.

    Arguments
    ---------
    in_channels : int
        Channels of the incoming tensor.
    out_channels : int
        Channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Downsampling convolution (kernel 8, stride 4).
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=8,
            stride=4,
            skip_transpose=True,
        )
        # Pointwise conv doubles the channels so the GLU can halve them again.
        self.glu_conv = Conv1d(
            in_channels=out_channels,
            out_channels=out_channels * 2,
            kernel_size=1,
            stride=1,
            skip_transpose=True,
        )
        self.relu = torch.nn.ReLU()
        self.glu = GLU(dim=1)

    def forward(self, x):
        """Map ``[batch, in_channels, time]`` to ``[batch, out_channels, time']``."""
        hidden = self.relu(self.conv(x))
        gated = self.glu_conv(hidden)
        return self.glu(gated)

class DecoderBlock(nn.Module):
    """One Demucs decoder stage.

    The stage optionally adds a skip connection, applies a 1x1 conv + GLU,
    then applies a transposed conv (kernel 8, stride 4) + ReLU. Tensors are
    channels-first ``[batch, channels, time]`` (``skip_transpose=True``).

    Arguments
    ---------
    in_channels : int
        Channels of the incoming tensor (and of the skip connection).
    out_channels : int
        Channels produced by the transposed convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Pointwise conv doubles the channels so the GLU can halve them back.
        self.glu_conv = Conv1d(
            out_channels=2*in_channels,
            in_channels=in_channels,
            kernel_size=1,
            stride=1,
            skip_transpose=True,
        )
        self.conv_tr = ConvTranspose1d(
            out_channels=out_channels,
            in_channels=in_channels,  # After GLU split
            kernel_size=8,
            stride=4,
            skip_transpose=True,
        )
        self.glu = GLU(dim=1)
        self.relu = torch.nn.ReLU()

    @staticmethod
    def _match_length(skip, length):
        """Return ``skip`` center-cropped or symmetrically zero-padded along its
        last axis so that it has exactly ``length`` time steps."""
        excess = skip.size(-1) - length
        if excess > 0:
            # The skip connection is longer than the decoder input: center-trim the skip.
            start = excess // 2
            return skip[..., start : start + length]
        if excess < 0:
            # The decoder input is longer than the skip connection: center-pad
            # the skip (any odd sample goes to the right).
            pad = -excess
            return nn.functional.pad(skip, (pad // 2, pad - pad // 2))
        return skip

    def forward(self, x, skip=None):
        """Decode ``x``, first adding ``skip`` (resized to ``x``'s length) if given."""
        if skip is not None:
            # Strided encoder convs and transposed decoder convs do not round-trip
            # time lengths exactly, so align the skip to x before the sum.
            x = x + self._match_length(skip, x.size(-1))

        x = self.glu(self.glu_conv(x))
        return self.relu(self.conv_tr(x))




class SourceSeparator(nn.Module):
    """Final linear projection from decoder features to per-source waveforms.

    No activation is applied after the projection.
    """

    def __init__(self, in_channels, out_channels=2, num_sources=4):
        """
        Args:
            in_channels: Channels of the last decoder output (8 in hparams.yaml).
            out_channels: Audio channels per source (2 for stereo).
            num_sources: Number of sources to separate (e.g. 4 for vocals,
                drums, bass, other).
        """
        super().__init__()
        # Final linear layer (no activation): one output per (source, channel) pair.
        self.output_proj = Linear(
            input_size=in_channels,
            n_neurons=num_sources * out_channels,
            bias=True,
        )
        self.num_sources = num_sources
        self.out_channels = out_channels

    def forward(self, x):
        """
        Input: [batch, in_channels, time]
        Output: [batch, num_sources, time, out_channels]
        (time is the third axis and audio channels the last one).
        """
        batch = x.size(0)
        # The projection acts on the last axis, so move features there:
        # [batch, time, in_channels] -> [batch, time, num_sources * out_channels]
        x = self.output_proj(x.permute(0, 2, 1))
        # Split the projected features into (source, channel) pairs ...
        x = x.view(batch, -1, self.num_sources, self.out_channels)
        # ... and put sources first: [batch, num_sources, time, out_channels].
        return x.permute(0, 2, 1, 3)

In [None]:
%%file hparams.yaml

# ################################
# Model: Demucs for source separation
# https://hal.science/hal-02379796/document
# Dataset: MUSDB18
# ################################
# Basic parameters
seed: 1234
__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]

# Data params (carried over from the DPRNN recipe)
# NOTE(review): hyperpyyaml expects a !PLACEHOLDER to be overridden
# (e.g. --data_folder=...) when the file is loaded. train.py itself reads
# db_path below, not data_folder. Confirm this override is supplied.
data_folder: !PLACEHOLDER

experiment_name: demucs
# Root for logs, checkpoints and saved audio of this run.
output_folder: !ref /kaggle/working/results/<experiment_name>/<seed>
train_log: !ref <output_folder>/train_log.txt
save_folder: !ref <output_folder>/save
# NOTE(review): the JSON manifest paths and skip_prep appear unused by
# train.py, which builds its datasets directly from db_path.
train_data: !ref <output_folder>/train.json
valid_data: !ref <output_folder>/valid.json
test_data: !ref <output_folder>/test.json
skip_prep: False
# MUSDB18 root that train.py passes to its dataset class.
db_path: '/kaggle/input/musdb18-music-source-separation-dataset'


# Experiment params
# Mixed precision; train.py switches fp16 to bf16 when the device is CPU.
precision: fp16
# train.py builds one target per source from batch.voc_sig and batch.inst_sig.
num_sources: 2

instrumental_classification: False
noprogressbar: False
# When True, up to n_audio_to_save test tracks are written as .wav files.
save_audio: True
# Passed to the dataset as target_sr and used when writing audio files.
sample_rate: 16000
n_audio_to_save: 10

####################### Training Parameters ####################################

N_epochs: 3
# NOTE(review): reaches training only via dataloader_opts. train.py builds its
# train/valid loaders with batch_size=1, which likely takes precedence.
batch_size: 2
lr: 0.00015
# Gradient-norm clipping threshold; fit_batch skips clipping when negative.
clip_grad_norm: 5
# fit_batch skips the optimizer step when the loss is not below this limit.
loss_upper_lim: 999999
# NOTE(review): these two are not read by train.py in this notebook.
limit_training_signal_len: False
training_signal_len: 32000000


# Data augmentation (kept from the original recipe)
# NOTE(review): compute_forward in train.py never applies these
# augmentations, so the settings in this section currently have no effect.
use_wavedrop: False
use_rand_shift: False
min_shift: -8000
max_shift: 8000


# Frequency/time drop (kept from the original recipe)
drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: 0
    drop_freq_high: 1
    drop_freq_count_low: 1
    drop_freq_count_high: 3
    drop_freq_width: 0.05

drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_length_low: 1000
    drop_length_high: 2000
    drop_count_low: 1
    drop_count_high: 5

# Loss thresholding: with threshold_byloss, fit_batch averages only the
# batch items whose loss is above `threshold`, dropping the easy ones.
threshold_byloss: True
threshold: -30

################ Demucs Specific Parameters #############################
## for Demucs V3/4
# # Fourier Transform Parameters
# n_fft: 2048
# hop_length: 512


# NOTE(review): unused. EncoderBlock/DecoderBlock in models.py hard-code
# kernel_size=8 and stride=4.
kernel_size: 16
# kernel_stride: 8

# Dataloader options: passed to fit() as train/valid loader kwargs and used
# by save_results to build the test loader.
dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: 3

######################## Network Definition ####################################

# Encoder: six blocks (kernel 8, stride 4 in models.py), widening the
# channels 2 (stereo mix) -> 64 -> 128 -> 256 -> 512 -> 1024 -> 2048.
Encoder1: !new:models.EncoderBlock
    in_channels: 2
    # kernel_size: !ref <kernel_size>
    out_channels: 64


Encoder2: !new:models.EncoderBlock
    in_channels: 64
    # kernel_size: !ref <kernel_size>
    out_channels: 128


Encoder3: !new:models.EncoderBlock
    in_channels: 128
    # kernel_size: !ref <kernel_size>
    out_channels: 256


Encoder4: !new:models.EncoderBlock
    in_channels: 256
    # kernel_size: !ref <kernel_size>
    out_channels: 512


Encoder5: !new:models.EncoderBlock
    in_channels: 512
    # kernel_size: !ref <kernel_size>
    out_channels: 1024


Encoder6: !new:models.EncoderBlock
    in_channels: 1024
    # kernel_size: !ref <kernel_size>
    out_channels: 2048



# Decoder: mirrors the encoder, 2048 -> 1024 -> ... -> 64 -> 8.
# Decoder1's 8 output channels must equal LinearSeparator.in_channels.
Decoder6: !new:models.DecoderBlock
    in_channels: 2048
    out_channels: 1024
    # # kernel_size: !ref <kernel_size>
    # stride: !ref <kernel_stride>


Decoder5: !new:models.DecoderBlock
    in_channels: 1024
    out_channels: 512
    # # kernel_size: !ref <kernel_size>
    # stride: !ref <kernel_stride>


Decoder4: !new:models.DecoderBlock
    in_channels: 512
    out_channels: 256
    # kernel_size: !ref <kernel_size>
    # stride: !ref <kernel_stride>


Decoder3: !new:models.DecoderBlock
    in_channels: 256
    out_channels: 128
    # kernel_size: !ref <kernel_size>
    # stride: !ref <kernel_stride>


Decoder2: !new:models.DecoderBlock
    in_channels: 128
    out_channels: 64
    # kernel_size: !ref <kernel_size>
    # stride: !ref <kernel_stride>


Decoder1: !new:models.DecoderBlock
    in_channels: 64
    out_channels: 8
    # kernel_size: !ref <kernel_size>
    # stride: !ref <kernel_stride>


# Projects the bidirectional LSTM output (2 directions x 2048 = 4096 features)
# back to the 2048 channels expected by Decoder6.
Linear: !new:speechbrain.nnet.linear.Linear
    input_size: 4096
    bias: False
    n_neurons: 2048

# Bottleneck sequence model over Encoder6's 2048-channel output
# (compute_forward permutes it to [batch, time, features] first).
BiLSTM: !new:speechbrain.nnet.RNN.LSTM
    hidden_size: 2048
    input_size: 2048
    num_layers: 2
    bidirectional: True
    # batch_first: True

# Maps Decoder1's 8 channels to num_sources x 2 (stereo) output channels.
LinearSeparator: !new:models.SourceSeparator
    in_channels: 8
    out_channels: 2
    num_sources: !ref <num_sources>


######################## Remaining Config ######################################
optimizer: !name:torch.optim.Adam
    lr: !ref <lr>
    weight_decay: 0

# loss: !name:speechbrain.nnet.losses.mse_loss
# Permutation-invariant SI-SNR loss (used by compute_objectives and save_results).
loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper


# Stepped in on_stage_end after each validation stage.
# NOTE(review): with N_epochs: 3 and dont_halve_until_epoch: 50, this
# scheduler presumably never halves the learning rate. Confirm this is intended.
lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
    factor: 0.5
    patience: 2
    dont_halve_until_epoch: 50

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <N_epochs>

# Keys are the attribute names compute_forward accesses via self.modules.<key>.
modules:
    encoder1: !ref <Encoder1>
    encoder2: !ref <Encoder2>
    encoder3: !ref <Encoder3>
    encoder4: !ref <Encoder4>
    encoder5: !ref <Encoder5>
    encoder6: !ref <Encoder6>
    lstm: !ref <BiLSTM>
    linear: !ref <Linear>
    decoder6: !ref <Decoder6>
    decoder5: !ref <Decoder5>
    decoder4: !ref <Decoder4>
    decoder3: !ref <Decoder3>
    decoder2: !ref <Decoder2>
    decoder1: !ref <Decoder1>
    linearSeparator: !ref <LinearSeparator>

# Saves the same modules plus the epoch counter. on_stage_end keeps the
# checkpoint with the lowest validation "si-snr" value.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        encoder1: !ref <Encoder1>
        encoder2: !ref <Encoder2>
        encoder3: !ref <Encoder3>
        encoder4: !ref <Encoder4>
        encoder5: !ref <Encoder5>
        encoder6: !ref <Encoder6>
        lstm: !ref <BiLSTM>
        linear: !ref <Linear>
        decoder6: !ref <Decoder6>
        decoder5: !ref <Decoder5>
        decoder4: !ref <Decoder4>
        decoder3: !ref <Decoder3>
        decoder2: !ref <Decoder2>
        decoder1: !ref <Decoder1>
        linearSeparator: !ref <LinearSeparator>
        counter: !ref <epoch_counter>


train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

In [None]:
%%file train.py
#!/usr/bin/env python3
"""Recipe for training a Demucs-style music source separation system on the
MUSDB18 dataset. The system employs a convolutional encoder, a BiLSTM
bottleneck, a convolutional decoder with skip connections, and a linear
layer that projects the decoder output onto the separated sources.

It is adapted from the SpeechBrain WSJ0-Mix separation recipe, which shipped
SepFormer, DualPath RNN and Conv-TasNet configurations.

To run this recipe, do the following:
> python train.py hparams.yaml

The architecture is defined by the parameter file, so channel widths and
other network settings can be changed there without editing this script.


Authors
 * Cem Subakan 2020
 * Mirco Ravanelli 2020
 * Samuele Cornell 2020
 * Mirko Bronzi 2020
 * Jianyuan Zhong 2020
"""
## CHECKPOINT
import csv
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm

import speechbrain as sb
import speechbrain.nnet.schedulers as schedulers
from speechbrain.utils.distributed import run_on_main
from speechbrain.utils.logger import get_logger
from speechbrain.nnet.CNN import Conv1d, ConvTranspose1d
# from speechbrain.nnet.activations import GLU
from speechbrain.lobes.models.beats import GLU_Linear
from torch.nn import GLU
from speechbrain.nnet.RNN import LSTM
from speechbrain.nnet.linear import Linear
from models import EncoderBlock, DecoderBlock
from speechbrain.nnet.losses import get_si_snr_with_pitwrapper

from torch.utils.data import Dataset
import musdb




# Define training procedure
class DemucsSeparation(sb.Brain):
    """SpeechBrain ``Brain`` that trains and evaluates the Demucs-style separator.

    The network is assembled from the modules registered under ``modules`` in
    the hparams file: ``encoder1``..``encoder6``, ``lstm``, ``linear``,
    ``decoder6``..``decoder1`` and ``linearSeparator``.

    Estimated and reference sources are both handled in the layout
    ``[batch, channels, time, num_sources]``. This assumes the dataset yields
    ``[channels, time]`` signals, as the 2-channel encoder input requires.
    TODO confirm against the dataset class.
    """

    def compute_forward(self, mix, targets, stage, noise=None):
        """Forward computations from the mixture to the separated signals.

        Arguments
        ---------
        mix : tuple
            ``(signal, lengths)`` from ``PaddedBatch``; the signal is fed to
            ``encoder1`` as ``[batch, channels, time]``.
        targets : list
            One ``(signal, lengths)`` pair per source; the first
            ``num_sources`` entries are used.
        stage : sb.Stage
            Not used here; part of the ``Brain`` interface.
        noise : optional
            Not used.

        Returns
        -------
        est_source, targets : torch.Tensor
            Both ``[batch, channels, time, num_sources]`` with the same number
            of time steps (the estimate is cropped or zero-padded at the end).
        """
        # Unpack lists and put tensors in the right device
        mix, mix_lens = mix
        mix, mix_lens = mix.to(self.device), mix_lens.to(self.device)

        # Stack the references on a trailing source axis:
        # [batch, channels, time] -> [batch, channels, time, num_sources]
        targets = torch.cat(
            [targets[i][0].unsqueeze(-1) for i in range(self.hparams.num_sources)],
            dim=-1,
        ).to(self.device)

        # Encoder
        mix_enc_1 = self.modules.encoder1(mix)
        mix_enc_2 = self.modules.encoder2(mix_enc_1)
        mix_enc_3 = self.modules.encoder3(mix_enc_2)
        mix_enc_4 = self.modules.encoder4(mix_enc_3)
        mix_enc_5 = self.modules.encoder5(mix_enc_4)
        mix_enc_6 = self.modules.encoder6(mix_enc_5)

        # BiLSTM bottleneck works on [batch, time, features]; the hidden
        # states it also returns are not needed.
        lstm_in = mix_enc_6.permute(0, 2, 1)
        lstm_out, _ = self.modules.lstm(lstm_in)
        lin_out = self.modules.linear(lstm_out).permute(0, 2, 1)

        # Decoder with skip connections from the matching encoder stage.
        mix_dec_6 = self.modules.decoder6(lin_out, skip=mix_enc_6)
        mix_dec_5 = self.modules.decoder5(mix_dec_6, skip=mix_enc_5)
        mix_dec_4 = self.modules.decoder4(mix_dec_5, skip=mix_enc_4)
        mix_dec_3 = self.modules.decoder3(mix_dec_4, skip=mix_enc_3)
        mix_dec_2 = self.modules.decoder2(mix_dec_3, skip=mix_enc_2)
        mix_dec_1 = self.modules.decoder1(mix_dec_2, skip=mix_enc_1)

        # linearSeparator returns [batch, sources, time, channels]; reorder to
        # [batch, channels, time, sources] so every axis lines up with targets
        # (save_audio and save_results index both tensors the same way).
        est_source = self.modules.linearSeparator(mix_dec_1).permute(0, 3, 2, 1)

        # The encoder/decoder strides change the number of time steps, so match
        # the time axis (dim 2) to the references. Both branches keep the
        # estimate anchored at t=0: extra steps are cut from the end and
        # missing steps are zero-padded at the end.
        T_origin = targets.size(2)
        T_est = est_source.size(2)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :, :T_origin, :]

        return est_source, targets

    def compute_objectives(self, predictions, targets):
        """Computes the SI-SNR loss configured as ``hparams.loss``.

        The loss tensor is returned without reduction; ``fit_batch`` and
        ``evaluate_batch`` reduce it.
        """
        # NOTE(review): SpeechBrain's PIT SI-SNR helper documents
        # [batch, time, sources] inputs, but these tensors carry an extra channel
        # axis (squeeze(dim=1) only removes it for mono). Confirm the loss
        # handles 4-D input.
        return self.hparams.loss(targets.squeeze(dim=1), predictions.squeeze(dim=1))

    def fit_batch(self, batch):
        """Trains one batch and returns its detached loss on the CPU."""
        mixture = batch.mix_sig
        targets = [batch.voc_sig, batch.inst_sig]

        with self.training_ctx:
            predictions, targets = self.compute_forward(
                mixture, targets, sb.Stage.TRAIN
            )
            loss = self.compute_objectives(predictions, targets)

            # Hard-threshold the easy items: only losses above `threshold`
            # contribute to the update.
            if self.hparams.threshold_byloss:
                th = self.hparams.threshold
                loss = loss[loss > th]
                if loss.nelement() > 0:
                    loss = loss.mean()
            else:
                loss = loss.mean()

        # Skip the step when every item was thresholded away or the loss is not
        # below loss_upper_lim (a NaN loss also fails the `<` comparison).
        if loss.nelement() > 0 and loss < self.hparams.loss_upper_lim:
            self.scaler.scale(loss).backward()
            if self.hparams.clip_grad_norm >= 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.modules.parameters(),
                    self.hparams.clip_grad_norm,
                )
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.nonfinite_count += 1
            logger.info(
                "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
                    self.nonfinite_count
                )
            )
            loss.data = torch.tensor(0.0).to(self.device)
        self.optimizer.zero_grad()

        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Computes the mean loss of a validation/test batch, saving audio on TEST."""
        snt_id = batch.track_id
        mixture = batch.mix_sig
        targets = [batch.voc_sig, batch.inst_sig]

        with torch.no_grad():
            predictions, targets = self.compute_forward(mixture, targets, stage)
            loss = self.compute_objectives(predictions, targets)

        # Manage audio file saving (bounded by n_audio_to_save when it is set).
        if stage == sb.Stage.TEST and self.hparams.save_audio:
            if hasattr(self.hparams, "n_audio_to_save"):
                if self.hparams.n_audio_to_save > 0:
                    self.save_audio(snt_id[0], mixture, targets, predictions)
                    self.hparams.n_audio_to_save += -1
            else:
                self.save_audio(snt_id[0], mixture, targets, predictions)

        return loss.mean().detach()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch: logs, anneals the LR, checkpoints."""
        stage_stats = {"si-snr": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats

        if stage == sb.Stage.VALID:
            # Learning rate annealing
            if isinstance(
                self.hparams.lr_scheduler, schedulers.ReduceLROnPlateau
            ):
                current_lr, next_lr = self.hparams.lr_scheduler(
                    [self.optimizer], epoch, stage_loss
                )
                schedulers.update_learning_rate(self.optimizer, next_lr)
            else:
                # Without ReduceLROnPlateau the LR is left unchanged. Read it from
                # the live optimizer: hparams.optimizer is only a constructor
                # (`!name:`) and has no `.optim` attribute.
                current_lr = self.optimizer.param_groups[0]["lr"]

            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": current_lr},
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"si-snr": stage_stats["si-snr"]}, min_keys=["si-snr"]
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )

    def save_results(self, test_data):
        """Computes SDR and SI-SNR (plus their improvements over the mixture)
        for every test item and writes them, with their averages, to
        ``<output_folder>/test_results.csv``."""

        # This package is required for SDR computation
        from mir_eval.separation import bss_eval_sources

        save_file = os.path.join(self.hparams.output_folder, "test_results.csv")

        all_sdrs = []
        all_sdrs_i = []
        all_sisnrs = []
        all_sisnrs_i = []
        csv_columns = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i"]

        test_loader = sb.dataio.dataloader.make_dataloader(
            test_data, **self.hparams.dataloader_opts
        )

        with open(save_file, "w", newline="", encoding="utf-8") as results_csv:
            writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
            writer.writeheader()

            with tqdm(test_loader, dynamic_ncols=True) as t:
                for batch in t:
                    # Apply Separation
                    mixture, _ = batch.mix_sig
                    snt_id = batch.track_id
                    targets = [batch.voc_sig, batch.inst_sig]

                    with torch.no_grad():
                        predictions, targets = self.compute_forward(
                            batch.mix_sig, targets, sb.Stage.TEST
                        )

                    # Compute SI-SNR on [batch, time, channels, sources]
                    predictions = predictions.permute(0, 2, 1, 3).squeeze(dim=-1)
                    targets = targets.permute(0, 2, 1, 3).squeeze(dim=-1)
                    sisnr = get_si_snr_with_pitwrapper(predictions, targets)

                    # Compute SI-SNR improvement over using the mixture as
                    # the estimate of every source.
                    mixture_signal = torch.stack(
                        [mixture] * self.hparams.num_sources, dim=-1
                    ).permute(0, 2, 1, 3).squeeze(dim=-1)
                    mixture_signal = mixture_signal.to(targets.device)
                    sisnr_baseline = get_si_snr_with_pitwrapper(
                        mixture_signal, targets
                    )
                    sisnr_i = sisnr - sisnr_baseline

                    # Compute SDR on the first item, averaging the audio
                    # channels (dim 1) into mono: [sources, time].
                    sdr, _, _, _ = bss_eval_sources(
                        targets[0].mean(dim=1).t().cpu().numpy(),
                        predictions[0].mean(dim=1).t().detach().cpu().numpy(),
                    )

                    sdr_baseline, _, _, _ = bss_eval_sources(
                        targets[0].mean(dim=1).t().cpu().numpy(),
                        mixture_signal[0].mean(dim=1).t().detach().cpu().numpy(),
                    )

                    sdr_i = sdr.mean() - sdr_baseline.mean()

                    # Saving on a csv file (the loss is negative SI-SNR, hence the sign flip)
                    row = {
                        "snt_id": snt_id[0],
                        "sdr": sdr.mean(),
                        "sdr_i": sdr_i,
                        "si-snr": -sisnr.item(),
                        "si-snr_i": -sisnr_i.item(),
                    }
                    writer.writerow(row)

                    # Metric Accumulation
                    all_sdrs.append(sdr.mean())
                    all_sdrs_i.append(sdr_i.mean())
                    all_sisnrs.append(-sisnr.item())
                    all_sisnrs_i.append(-sisnr_i.item())

                row = {
                    "snt_id": "avg",
                    "sdr": np.array(all_sdrs).mean(),
                    "sdr_i": np.array(all_sdrs_i).mean(),
                    "si-snr": np.array(all_sisnrs).mean(),
                    "si-snr_i": np.array(all_sisnrs_i).mean(),
                }
                writer.writerow(row)

        logger.info("Mean SISNR is {}".format(np.array(all_sisnrs).mean()))
        logger.info("Mean SISNRi is {}".format(np.array(all_sisnrs_i).mean()))
        logger.info("Mean SDR is {}".format(np.array(all_sdrs).mean()))
        logger.info("Mean SDRi is {}".format(np.array(all_sdrs_i).mean()))

    @staticmethod
    def _peak_normalize(signal, eps=1e-8):
        """Scale ``signal`` so its largest absolute sample is 1.

        Near-silent input is divided by ``eps`` instead of ~0, which avoids
        writing NaNs to disk.
        """
        return signal / signal.abs().max().clamp_min(eps)

    def save_audio(self, snt_id, mixture, targets, predictions):
        """Saves the test audio (mixture, targets, and estimated sources) on disk.

        Only the first batch item is written, as ``[channels, time]`` files
        under ``<save_folder>/audio_results``.
        """
        save_path = os.path.join(self.hparams.save_folder, "audio_results")
        os.makedirs(save_path, exist_ok=True)

        for ns in range(self.hparams.num_sources):
            # Estimated source
            signal = self._peak_normalize(predictions[0, :, :, ns])
            save_file = os.path.join(
                save_path, "item{}_source{}hat.wav".format(snt_id, ns + 1)
            )
            torchaudio.save(save_file, signal.cpu(), self.hparams.sample_rate)

            # Original source
            signal = self._peak_normalize(targets[0, :, :, ns])
            save_file = os.path.join(
                save_path, "item{}_source{}.wav".format(snt_id, ns + 1)
            )
            torchaudio.save(save_file, signal.cpu(), self.hparams.sample_rate)

        # Mixture (mixture is the PaddedBatch (signal, lengths) pair)
        signal = self._peak_normalize(mixture[0][0, :])
        save_file = os.path.join(save_path, "item{}_mix.wav".format(snt_id))
        torchaudio.save(save_file, signal.cpu(), self.hparams.sample_rate)



if __name__ == "__main__":
    # Script entry point: load hparams, build the datasets/loaders, train the
    # separator, then evaluate on the test subset and write per-track metrics.

    # Load hyperparameters file with command-line overrides
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    with open(hparams_file, encoding="utf-8") as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Initialize ddp (useful only for multi-GPU DDP training)
    sb.utils.distributed.ddp_init_group(run_opts)

    # Logger info
    # NOTE(review): `logger` is created only here, yet DemucsSeparation.fit_batch
    # and save_results use it as a module global. Those methods therefore only
    # work when this file is run as a script.
    logger = get_logger(__name__)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Update precision to bf16 if the device is CPU and precision is fp16
    if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
        hparams["precision"] = "bf16"



    # Build the MUSDB18 datasets (train/valid splits of the "train" subset,
    # plus the "test" subset), resampled via target_sr to hparams sample_rate.
    # NOTE(review): LazyMusDBDataset is neither defined nor imported in
    # train.py, so these lines raise NameError as written. It presumably lives
    # in another notebook cell/module that must be imported here. TODO confirm.
    train_data = LazyMusDBDataset(hparams["db_path"], subset="train", split="train", target_sr=hparams["sample_rate"], chunk_size=30)
    valid_data = LazyMusDBDataset(hparams["db_path"], subset="train", split="valid", target_sr=hparams["sample_rate"], chunk_size=30)
    test_data = LazyMusDBDataset(hparams["db_path"], subset="test", target_sr=hparams["sample_rate"])


    # Create DataLoader
    # NOTE(review): fit() below receives these ready-made DataLoaders, so its
    # train_loader_kwargs / valid_loader_kwargs (dataloader_opts: batch_size 2,
    # num_workers 3) are most likely ignored and batch_size=1 here wins.
    # Confirm against the installed SpeechBrain version.
    train_loader = sb.dataio.dataloader.make_dataloader(
        train_data,
        batch_size=1,
        collate_fn=sb.dataio.batch.PaddedBatch  # Handles variable lengths
    )

    valid_loader = sb.dataio.dataloader.make_dataloader(
        valid_data,
        batch_size=1,
        collate_fn=sb.dataio.batch.PaddedBatch  # Handles variable lengths
    )


    # Brain class initialization
    separator = DemucsSeparation(
        modules=hparams["modules"],
        opt_class=hparams["optimizer"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )


    # Training
    separator.fit(
        separator.hparams.epoch_counter,
        train_loader,
        valid_loader,
        train_loader_kwargs=hparams["dataloader_opts"],
        valid_loader_kwargs=hparams["dataloader_opts"],
    )

    # Eval: load the checkpoint with the lowest validation "si-snr" value,
    # evaluate it on the test set, then write the per-track metrics CSV.
    separator.evaluate(test_data, min_key="si-snr")
    separator.save_results(test_data)
    ## CHECKPOINT