<a href="https://colab.research.google.com/github/fadelmuli/mss-sepformer/blob/main/sepformer_4_sources.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install speechbrain
!pip install musdb
!pip install asteroid

In [None]:
import speechbrain as sb
from speechbrain.lobes.models.dual_path import SepformerWrapper

import torch
import torch.nn as nn
import torch.utils.data as data_utils
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import musdb

import librosa
from asteroid.data import MUSDB18Dataset

from pathlib import Path

import os
import numpy as np
import matplotlib.pyplot as plt

os.environ['MUSDB_PATH'] = '/content'

# Load Data

## MUSDB18 preview tracks (mp4 stems, loaded with `musdb`)

In [None]:
# Download the MUSDB18 preview set and turn every mixture and stem into a mono
# float32 tensor (channels averaged), one tensor per track.
# All three splits are resampled to the same rate so train/val/test agree with each
# other (and with the 8 kHz playback used at the end of the notebook).
MUSDB_SAMPLE_RATE = 8000
STEM_NAMES = ('vocals', 'bass', 'drums', 'other')

mus_train = musdb.DB(download=True, subsets="train", split="train", sample_rate=MUSDB_SAMPLE_RATE)
mus_val = musdb.DB(download=True, subsets="train", split="valid", sample_rate=MUSDB_SAMPLE_RATE)
mus_test = musdb.DB(download=True, subsets="test", sample_rate=MUSDB_SAMPLE_RATE)

train_list = list(mus_train)
val_list = list(mus_val)
test_list = list(mus_test)


def _to_mono(audio):
    """Average a (samples, channels) array over channels -> float32 tensor of shape (samples,)."""
    # Averaging over the last axis also works for single-channel audio, where
    # squeeze() followed by mean(dim=1) raised an IndexError.
    return torch.as_tensor(audio).mean(dim=-1).float()


def _mono_split(tracks):
    """Return [mixes, vocals, bass, drums, other]: lists of mono tensors, one entry per track."""
    mixes = [_to_mono(track.audio) for track in tracks]
    stems = [[_to_mono(track.targets[name].audio) for track in tracks] for name in STEM_NAMES]
    return [mixes, *stems]


train_mixs, train_vocals, train_bass, train_drums, train_other = _mono_split(train_list)
val_mixs, val_vocals, val_bass, val_drums, val_other = _mono_split(val_list)
test_mixs, test_vocals, test_bass, test_drums, test_other = _mono_split(test_list)

# Sanity check: every stem list is aligned with its mixture list.
for split_name, split_lists in (
    ("train", (train_mixs, train_vocals, train_bass, train_drums, train_other)),
    ("val", (val_mixs, val_vocals, val_bass, val_drums, val_other)),
    ("test", (test_mixs, test_vocals, test_bass, test_drums, test_other)),
):
    counts = [len(items) for items in split_lists]
    assert len(set(counts)) == 1, f"{split_name}: mismatched track counts {counts}"
    print(f"{split_name}: {counts[0]} tracks")

Downloading MUSDB 7s Sample Dataset to /root/MUSDB18/MUSDB18-7...
Done!
80
80
80
80
80
14
14
14
14
14
50
50
50
50
50


In [None]:
# Cache the mono tensors on Google Drive so later sessions can skip download/decoding.
# torch.save does not create directories, so the MusDB folder is created if missing.
# Parent folders are deliberately NOT created: if Drive is not mounted this fails loudly
# instead of silently writing to the local /content/drive path.
_musdb_cache_dir = Path('/content/drive/MyDrive/MusDB')
_musdb_cache_dir.mkdir(exist_ok=True)

for _cache_name, _cache_value in (
    ('train_mixs', train_mixs),
    ('train_vocals', train_vocals),
    ('train_bass', train_bass),
    ('train_drums', train_drums),
    ('train_other', train_other),
    ('val_mixs', val_mixs),
    ('val_vocals', val_vocals),
    ('val_bass', val_bass),
    ('val_drums', val_drums),
    ('val_other', val_other),
    ('test_mixs', test_mixs),
    ('test_vocals', test_vocals),
    ('test_bass', test_bass),
    ('test_drums', test_drums),
    ('test_other', test_other),
):
    torch.save(_cache_value, _musdb_cache_dir / f'{_cache_name}.pt')

In [None]:
# Reload the cached mono tensors from Drive: for each split, the mixture list followed
# by the vocals, bass, drums and other stem lists.
_CACHE_STEMS = ('mixs', 'vocals', 'bass', 'drums', 'other')

train_mixs, train_vocals, train_bass, train_drums, train_other = [
    torch.load(f'/content/drive/MyDrive/MusDB/train_{stem}.pt') for stem in _CACHE_STEMS
]
val_mixs, val_vocals, val_bass, val_drums, val_other = [
    torch.load(f'/content/drive/MyDrive/MusDB/val_{stem}.pt') for stem in _CACHE_STEMS
]
test_mixs, test_vocals, test_bass, test_drums, test_other = [
    torch.load(f'/content/drive/MyDrive/MusDB/test_{stem}.pt') for stem in _CACHE_STEMS
]

In [None]:
# Load the "_500" variant of the cached tensors.
# NOTE: this rebinds every split/stem variable, replacing anything loaded before it.
def _load_cached_500(split):
    """Return (mixs, vocals, bass, drums, other) loaded from '<split>_<stem>_500.pt' files."""
    return tuple(
        torch.load('/content/drive/MyDrive/MusDB/' + split + '_' + stem + '_500.pt')
        for stem in ('mixs', 'vocals', 'bass', 'drums', 'other')
    )


train_mixs, train_vocals, train_bass, train_drums, train_other = _load_cached_500('train')
val_mixs, val_vocals, val_bass, val_drums, val_other = _load_cached_500('val')
test_mixs, test_vocals, test_bass, test_drums, test_other = _load_cached_500('test')

In [None]:
class source_separation_dataset(Dataset):
    """Map-style dataset over parallel lists of mixtures and their four stems.

    Item ``idx`` is the tuple ``(mix, vocals, bass, drums, other, mask)``. ``mask`` holds
    one float per stem in that order: 1.0 when the stem contains at least one non-zero
    sample, 0.0 when it is entirely silent.
    """

    def __init__(self, train_mixs, train_vocals, train_bass, train_drums, train_other):
        # The sequences are stored as given and indexed in lockstep by __getitem__.
        self.mixs = train_mixs
        self.train_vocals = train_vocals
        self.train_bass = train_bass
        self.train_drums = train_drums
        self.train_other = train_other

    def __len__(self):
        # The mixture list defines the dataset size.
        return len(self.mixs)

    def __getitem__(self, idx):
        mix = self.mixs[idx]
        stems = [
            self.train_vocals[idx],
            self.train_bass[idx],
            self.train_drums[idx],
            self.train_other[idx],
        ]
        # Give each stem its own trailing column, then flag columns with any non-zero entry.
        columns = torch.cat([stem.unsqueeze(-1) for stem in stems], dim=1)
        present = torch.any(columns != 0, dim=0).float()
        return (mix, *stems, present)

# Wrap each split in the dataset class and a DataLoader with batch_size=1, i.e. one
# track per batch (the default collate cannot stack tensors of unequal length).
_split_tensors = {
    'train': (train_mixs, train_vocals, train_bass, train_drums, train_other),
    'valid': (val_mixs, val_vocals, val_bass, val_drums, val_other),
    'test': (test_mixs, test_vocals, test_bass, test_drums, test_other),
}
train_dataset_audio = source_separation_dataset(*_split_tensors['train'])
valid_dataset_audio = source_separation_dataset(*_split_tensors['valid'])
test_dataset_audio = source_separation_dataset(*_split_tensors['test'])

train_loader_audio, valid_loader_audio, test_loader_audio = (
    DataLoader(dataset, batch_size=1)
    for dataset in (train_dataset_audio, valid_dataset_audio, test_dataset_audio)
)

In [None]:
def sdr_objective(estimation, origin, mask=None, eps=1e-8):
    """
    Scale-invariant signal-to-distortion ratio (SI-SDR, a.k.a. SI-SNR) in dB.

    The last axis is time. ``estimation`` is split into its projection onto ``origin``
    (the target part) and the remainder (the error part); the SDR is the log ratio of
    their energies. The resulting SDR tensor has shape ``estimation.shape[:-1]`` and is
    then averaged over its last axis.

    Arguments:
        estimation {torch.tensor} -- estimated signals, shape (..., T)
        origin {torch.tensor} -- reference signals, same shape as ``estimation``
    Keyword Arguments:
        mask {torch.tensor, None} -- weights with the same shape as the SDR tensor
            (``estimation.shape[:-1]``). Entries equal to 1 are KEPT in the average and
            entries equal to 0 are excluded (in this notebook: 1 for stems containing
            any non-zero sample). It is broadcast, so a mask with transposed axes
            silently produces wrong results. (default: {None} -> plain mean)
        eps {float} -- floor that avoids division by zero and log of zero (default: {1e-8})
    Returns:
        torch.tensor -- SDR averaged over the last SDR axis, shape ``estimation.shape[:-2]``
    """
    origin_power = origin.pow(2).sum(dim=-1, keepdim=True) + eps  # (..., 1)
    scale = (origin * estimation).sum(dim=-1, keepdim=True) / origin_power  # (..., 1)

    est_true = scale * origin  # projection onto the reference, (..., T)
    est_res = estimation - est_true  # everything else, (..., T)

    true_power = est_true.pow(2).sum(dim=-1).clamp(min=eps)  # (...)
    res_power = est_res.pow(2).sum(dim=-1).clamp(min=eps)  # (...)

    sdr = 10 * (torch.log10(true_power) - torch.log10(res_power))  # (...)

    if mask is None:
        return sdr.mean(dim=-1)
    # Weighted mean over the last axis; all-zero mask rows give 0 instead of NaN.
    return (sdr * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=eps)

def dissimilarity_loss(latents, mask):
    """
    Penalise similarity between the latent representations of different stems.

    For every unordered stem pair (i, j) the cosine similarity of the absolute latents
    is taken over the feature axis. It is summed over pairs and batch, divided by the
    number of (batch item, pair) combinations in which both stems are present, and
    averaged over frames. Separation.compute_objectives_2 adds it with a positive weight,
    so it is minimised.

    The pairs are derived from ``latents.shape[0]``; for four stems they are
    (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).

    Arguments:
        latents {torch.tensor} -- shape (S, B, T', N): stems, batch, frames, features
        mask {torch.tensor} -- shape (B, S); 1 where the stem is present (non-silent)
    Returns:
        torch.tensor -- scalar (0 when there are fewer than two stems)
    """
    from itertools import combinations

    pairs = list(combinations(range(latents.shape[0]), 2))
    if not pairs:
        return latents.new_zeros(())
    first = [i for i, _ in pairs]
    second = [j for _, j in pairs]

    a = latents[first]  # (P, B, T', N)
    b = latents[second]  # (P, B, T', N)

    # Number of (batch item, pair) combinations where both stems are present.
    count = (mask[:, first] * mask[:, second]).sum().clamp(min=1e-8)
    sim = F.cosine_similarity(a.abs(), b.abs(), dim=-1)  # (P, B, T')
    sim = sim.sum(dim=(0, 1)) / count  # (T',)
    return sim.mean()

def similarity_loss(latents, mask):
    """
    Encourage each stem's latents to match the same stem in another batch example.

    Every batch item is compared with its neighbour along the batch axis (``torch.roll``
    by one) using cosine similarity over the feature axis. The result is summed over
    stems and batch, divided by the number of (item, stem) entries present in both the
    item and its neighbour, and averaged over frames. Separation.compute_objectives_2
    subtracts it with a positive weight, so it is maximised.

    NOTE: with batch size 1 the roll is a no-op, so each latent is compared with itself
    and the term carries essentially no training signal.

    Arguments:
        latents {torch.tensor} -- shape (S, B, T', N): stems, batch, frames, features
        mask {torch.tensor} -- shape (B, S); 1 where the stem is present (non-silent)
    Returns:
        torch.tensor -- scalar
    """
    neighbour_latents = torch.roll(latents, 1, dims=1)  # shift along the batch axis
    neighbour_mask = torch.roll(mask, 1, dims=0)  # same shift for the (B, S) mask

    count = (mask * neighbour_mask).sum().clamp(min=1e-8)
    sim = F.cosine_similarity(latents, neighbour_latents, dim=-1)  # (S, B, T')
    sim = sim.sum(dim=(0, 1)) / count  # (T',)
    return sim.mean()

In [None]:
# Scratch check: right-padding a (1, 500) tensor with zeros along the last axis up to
# length 1000 gives shape (1, 1000).
contoh = torch.randn((1, 500))
contoh_2 = F.pad(contoh, (0, 1000 - contoh.shape[-1]))
contoh_2.shape

torch.Size([1, 1000])

In [None]:
class Separation(sb.Brain):
    """SepFormer-style four-stem separator trained on ``source_separation_dataset`` batches.

    The training loss (``compute_objectives_2``) is the negative SI-SDR of the separated
    stems plus optional terms weighted by the hparams ``reconstruction_loss_weight``,
    ``similarity_loss_weight`` and ``dissimilarity_loss_weight``. Other hparams used:
    ``num_spks``, ``lr_scheduler``, ``train_logger`` and an optional ``clip_grad_norm``
    (default 5). Modules: ``enc``, ``masker`` and ``dec``.
    """

    def compute_forward(self, mix, targets, stage, noise=None):
        """Forward computations from the mixture to the separated signals.

        Arguments:
            mix -- mixture batch; dim 1 is time (assumed (B, T)).
            targets -- list of ``num_spks`` reference tensors shaped like ``mix``.
            stage -- sb.Stage; unused.
            noise -- unused.
        Returns:
            (est_source, targets, true_latents, est_mix):
            est_source and targets are (B, T, S); true_latents is the encoder output of
            each reference stem arranged as (B, N, T', S); est_mix is the decoded,
            unmasked mixture encoding trimmed/padded to T.
        """
        mix = mix.to(self.device)
        num_spks = self.hparams.num_spks

        # Stack the reference stems on a trailing axis -> (B, T, S).
        targets = torch.cat(
            [targets[i].unsqueeze(-1) for i in range(num_spks)],
            dim=-1,
        ).to(self.device)

        # Separation: one mask per stem applied to the shared mixture encoding.
        mix_w_1 = self.modules.enc(mix)
        est_mask = self.modules.masker(mix_w_1)
        mix_w = torch.stack([mix_w_1] * num_spks)
        sep_h = mix_w * est_mask

        # Decoding
        est_source = torch.cat(
            [self.modules.dec(sep_h[i]).unsqueeze(-1) for i in range(num_spks)],
            dim=-1,
        )

        # The encoder/decoder round trip can change T; pad or trim back to the input length.
        T_origin = mix.size(1)
        T_est = est_source.size(1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin, :]

        # Encode every reference stem separately. ``targets`` is (B, T, S) with the stem
        # axis innermost in memory, so a plain view(S, -1) would interleave stems and time
        # steps. Move the stem axis first and fold it into the batch instead.
        batch_size = targets.size(0)
        stems_flat = targets.permute(2, 0, 1).reshape(num_spks * batch_size, -1)  # (S*B, T)
        latents = self.modules.enc(stems_flat)  # (S*B, N, T')
        true_latents = latents.reshape(
            num_spks, batch_size, latents.size(1), latents.size(2)
        ).permute(1, 2, 3, 0)  # (B, N, T', S)

        # Reconstruct the mixture from its unmasked encoding (for the reconstruction loss).
        est_mix = self.modules.dec(mix_w_1)
        T_decmix = est_mix.size(1)
        if T_origin > T_decmix:
            est_mix = F.pad(est_mix, (0, T_origin - T_decmix))
        else:
            est_mix = est_mix[:, :T_origin]

        return est_source, targets, true_latents, est_mix

    def compute_objectives_2(self, predictions, targets, mask, true_latents, estimated_mix, true_mix):
        """Total loss (lower is better).

        Arguments:
            predictions -- (B, T, S) estimated stems.
            targets -- (B, T, S) reference stems.
            mask -- (B, S); 1 for stems with any non-zero sample.
            true_latents -- (B, N, T', S) latents of the reference stems.
            estimated_mix -- reconstructed mixture, same shape as ``true_mix``.
            true_mix -- input mixture.
        Returns:
            Scalar tensor: -sum of per-stem SI-SDR, minus the weighted reconstruction SI-SDR,
            plus the weighted dissimilarity, minus the weighted similarity.
        """
        estimated_separation = predictions.permute(2, 0, 1)  # (S, B, T)
        true_separation = targets.permute(2, 0, 1)  # (S, B, T)
        true_latents = true_latents.permute(0, 3, 2, 1)  # (B, S, T', N)

        # sdr_objective yields an (S, B) SDR tensor here, so the mask must be (S, B) too.
        # Passing the (B, S) mask directly broadcast to an outer product, which for B=1
        # cancelled out and silently disabled the masking of silent stems.
        sdr = sdr_objective(estimated_separation, true_separation, mask.transpose(0, 1))
        total_loss = -sdr.sum()

        if self.hparams.reconstruction_loss_weight > 0:
            reconstruction_sdr = sdr_objective(estimated_mix, true_mix).mean()
        else:
            reconstruction_sdr = 0.0
        total_loss = total_loss - self.hparams.reconstruction_loss_weight * reconstruction_sdr

        if self.hparams.similarity_loss_weight > 0.0 or self.hparams.dissimilarity_loss_weight > 0.0:
            # Zero the latents of silent stems, then put stems first: (S, B, T', N).
            true_latents = true_latents * mask.unsqueeze(-1).unsqueeze(-1)
            true_latents = true_latents.transpose(0, 1)

        dissimilarity = dissimilarity_loss(true_latents, mask) if self.hparams.dissimilarity_loss_weight > 0.0 else 0.0
        total_loss = total_loss + self.hparams.dissimilarity_loss_weight * dissimilarity

        # NOTE: similarity_loss compares batch items with each other; with batch_size=1 it
        # carries essentially no training signal.
        similarity = similarity_loss(true_latents, mask) if self.hparams.similarity_loss_weight > 0.0 else 0.0
        total_loss = total_loss - self.hparams.similarity_loss_weight * similarity
        return total_loss

    def compute_objectives(self, predictions, targets):
        """Permutation-invariant SI-SNR loss from SpeechBrain (unused by fit_batch here)."""
        return sb.nnet.losses.get_si_snr_with_pitwrapper(targets, predictions)

    def fit_batch(self, batch):
        """Run one optimisation step on a ``source_separation_dataset`` batch; return the loss."""
        # Batch layout: (mix, vocals, bass, drums, other, mask).
        mix, vocals, bass, drums, other, mask = (item.to(self.device) for item in batch[:6])
        targets = [vocals, bass, drums, other]

        predictions, targets, true_latents, est_mix = self.compute_forward(mix, targets, sb.Stage.TRAIN)
        loss = self.compute_objectives_2(predictions, targets, mask, true_latents, est_mix, mix)

        # Keep only losses above the threshold; if none survive, the loss is left unchanged.
        th = -30
        loss_to_keep = loss[loss > th]
        if loss_to_keep.nelement() > 0:
            loss = loss_to_keep.mean()

        # NaN/inf losses compare False and fall through to the skip branch.
        if loss < 999999 and loss.nelement() > 0:
            loss.backward()
            clip = getattr(self.hparams, "clip_grad_norm", 5)
            if clip >= 0:
                torch.nn.utils.clip_grad_norm_(self.modules.parameters(), clip)
            self.optimizer.step()
        else:
            # Stored on the instance: a bare local counter raised UnboundLocalError here.
            self.nonfinite_count = getattr(self, "nonfinite_count", 0) + 1
            print(
                "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
                    self.nonfinite_count
                )
            )
            # Report 0 for the skipped batch (a float tensor, matching the loss dtype family).
            loss = torch.tensor(0.0, device=self.device)

        self.optimizer.zero_grad()

        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Compute the loss for a validation/test batch without tracking gradients."""
        mix, vocals, bass, drums, other, mask = (item.to(self.device) for item in batch[:6])
        targets = [vocals, bass, drums, other]

        with torch.no_grad():
            predictions, targets, true_latents, est_mix = self.compute_forward(mix, targets, stage)
            loss = self.compute_objectives_2(predictions, targets, mask, true_latents, est_mix, mix)

        return loss.detach()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Store train stats; after validation anneal the LR and log; after test log."""
        stage_stats = {"si-snr": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats

        if stage == sb.Stage.VALID:
            if isinstance(self.hparams.lr_scheduler, sb.nnet.schedulers.ReduceLROnPlateau):
                current_lr, next_lr = self.hparams.lr_scheduler(
                    [self.optimizer], epoch, stage_loss
                )
                sb.nnet.schedulers.update_learning_rate(self.optimizer, next_lr)
            else:
                # No plateau scheduler: report the optimizer's current LR unchanged.
                # (hparams.optimizer is a factory lambda and has no ``.optim`` attribute.)
                current_lr = self.optimizer.param_groups[0]["lr"]

            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": current_lr},
                train_stats=self.train_stats,
                valid_stats=stage_stats)

        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": 1},
                test_stats=stage_stats,
            )

    def reset_layer_recursively(self, layer):
        """Re-initialise ``layer`` and every sub-module beneath it.

        Walks ``children()`` so each module is visited once; recursing over ``modules()``
        (which already yields all descendants) re-initialised deep layers many times.
        """
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
        for child_layer in layer.children():
            self.reset_layer_recursively(child_layer)

In [None]:
# Four-stem dual-path model: encoder -> dual-path transformer masker -> decoder.
# The intra- and inter-chunk transformers share hyperparameters but are separate modules.
_sb_dual_path = sb.lobes.models.dual_path
_tf_kwargs = dict(
    num_layers=8,
    d_model=256,
    nhead=8,
    d_ffn=1024,
    dropout=0,
    use_positional_encoding=True,
    norm_before=True,
)
tf_blocks = {
    'SBtfintra': _sb_dual_path.SBTransformerBlock(**_tf_kwargs),
    'SBtfinter': _sb_dual_path.SBTransformerBlock(**_tf_kwargs),
}

modules = {
    "enc": _sb_dual_path.Encoder(kernel_size=16, out_channels=256),
    "masker": _sb_dual_path.Dual_Path_Model(
        num_spks=4,
        in_channels=256,
        out_channels=256,
        num_layers=2,
        K=250,
        intra_model=tf_blocks['SBtfintra'],
        inter_model=tf_blocks['SBtfinter'],
        norm='ln',
        linear_layer_after_inter_intra=False,
        skip_around_intra=True,
    ),
    "dec": _sb_dual_path.Decoder(in_channels=256, out_channels=1, kernel_size=16, stride=8, bias=False),
}

# Training settings; the *_loss_weight entries scale the terms of compute_objectives_2.
hparams = {
    'num_spks': 4,
    'lr_scheduler': sb.nnet.schedulers.ReduceLROnPlateau(
        factor=0.5, patience=2, dont_halve_until_epoch=85
    ),
    'optimizer': lambda params: torch.optim.Adam(params, lr=0.00015),
    'train_logger': sb.utils.train_logger.FileTrainLogger('/content/drive/MyDrive/MusDB/train_log_1.txt'),
    'reconstruction_loss_weight': 0.05,
    'similarity_loss_weight': 2.0,
    'dissimilarity_loss_weight': 3.0,
}

brain = Separation(
    modules,
    hparams=hparams,
    opt_class=hparams['optimizer'],
    run_opts={"device": "cuda:0"},
)

In [None]:
# Re-initialise the parameters of every top-level module before training.
for _module_name, _module in brain.modules.items():
    brain.reset_layer_recursively(_module)

In [None]:
# Train for 10 epochs, running the validation loader after each one.
brain.fit(epoch_counter=range(10), train_set=train_loader_audio, valid_set=valid_loader_audio)

100%|██████████| 80/80 [00:58<00:00,  1.36it/s, train_loss=20.6]
100%|██████████| 14/14 [00:20<00:00,  1.46s/it]
100%|██████████| 80/80 [00:58<00:00,  1.36it/s, train_loss=10.3]
100%|██████████| 14/14 [00:20<00:00,  1.46s/it]
100%|██████████| 80/80 [00:58<00:00,  1.36it/s, train_loss=8.51]
100%|██████████| 14/14 [00:20<00:00,  1.46s/it]
100%|██████████| 80/80 [00:58<00:00,  1.36it/s, train_loss=7.46]
100%|██████████| 14/14 [00:20<00:00,  1.46s/it]
100%|██████████| 80/80 [00:58<00:00,  1.36it/s, train_loss=6.57]
100%|██████████| 14/14 [00:20<00:00,  1.46s/it]
100%|██████████| 80/80 [00:58<00:00,  1.36it/s, train_loss=5.89]
100%|██████████| 14/14 [00:20<00:00,  1.46s/it]
100%|██████████| 80/80 [00:58<00:00,  1.36it/s, train_loss=5.27]
100%|██████████| 14/14 [00:20<00:00,  1.46s/it]
100%|██████████| 80/80 [00:58<00:00,  1.36it/s, train_loss=4.77]
100%|██████████| 14/14 [00:20<00:00,  1.46s/it]
100%|██████████| 80/80 [00:58<00:00,  1.36it/s, train_loss=4.29]
100%|██████████| 14/14 [00:20<0

# WAV tracks (16 kHz, loaded with asteroid's `MUSDB18Dataset`)

In [None]:
# 16 kHz WAV copy of MUSDB18, read by asteroid's MUSDB18Dataset.
root_path = Path('/content/drive/MyDrive/MusDB/wav_16k')

# Tracks from the "train" folder that are held out for validation.
validation_tracks = [
    "Actions - One Minute Smile",
    "Clara Berry And Wooldog - Waltz For My Victims",
    "Johnny Lokke - Promises & Lies",
    "Patrick Talbot - A Reason To Leave",
    "Triviul - Angelsaint",
    "Alexander Ross - Goodbye Bolero",
    "Fergessen - Nos Palpitants",
    "Leaf - Summerghost",
    "Skelpolu - Human Mistakes",
    "Young Griffo - Pennies",
    "ANiMAL - Rockshow",
    "James May - On The Line",
    "Meaxic - Take A Step",
    "Traffic Experiment - Sirens",
]

# Training tracks: every track folder under the same root's "train" split, minus the
# held-out ones. Listing from root_path keeps the names in sync with the folders the
# datasets read; non-directories (e.g. stray files) are skipped and the order is sorted.
train_list = sorted(entry.name for entry in (root_path / 'train').iterdir() if entry.is_dir())
_held_out = set(validation_tracks)
train_subset = [song for song in train_list if song not in _held_out]
len(train_subset)

86

In [None]:
# asteroid MUSDB18Dataset views over the 16 kHz WAV copy.
_wav_sample_rate = 16000

# Training: non-held-out tracks, 6-second segments, random_track_mix enabled.
train_dataset = MUSDB18Dataset(
    root_path,
    split="train",
    subset=train_subset,
    random_track_mix=True,
    segment=6,
    sample_rate=_wav_sample_rate,
)
# Validation: held-out tracks from the "train" folder; test: the "test" folder.
# No segment is given for either, so asteroid's default applies.
valid_dataset = MUSDB18Dataset(root_path, split="train", subset=validation_tracks, sample_rate=_wav_sample_rate)
test_dataset = MUSDB18Dataset(root_path, split="test", sample_rate=_wav_sample_rate)

100it [02:05,  1.26s/it]
100it [00:19,  5.15it/s]
50it [01:11,  1.42s/it]


In [None]:
# One (mix, sources) example per batch for each split.
train_loader_audio, valid_loader_audio, test_loader_audio = (
    DataLoader(dataset, batch_size=1)
    for dataset in (train_dataset, valid_dataset, test_dataset)
)

# Optional: pickle the loaders to Drive so a later session can torch.load them.
#torch.save(train_loader_audio, '/content/drive/MyDrive/MusDB/train_loader.pth')
#torch.save(valid_loader_audio, '/content/drive/MyDrive/MusDB/valid_loader.pth')
#torch.save(test_loader_audio, '/content/drive/MyDrive/MusDB/test_loader.pth')

In [None]:
# Reload pickled DataLoaders from Drive; this rebinds the three loader variables.
# NOTE(review): torch.load unpickles arbitrary objects -- only load files you created.
# Newer PyTorch releases default to weights_only=True, which may refuse these objects.
train_loader_audio, valid_loader_audio, test_loader_audio = [
    torch.load('/content/drive/MyDrive/MusDB/' + split + '_loader.pth')
    for split in ('train', 'valid', 'test')
]

In [None]:
class Separation(sb.Brain):
    """SepFormer separator for asteroid ``MUSDB18Dataset`` batches, trained with PIT SI-SNR.

    NOTE: this cell redefines ``Separation``; brains created afterwards use this version.
    hparams used: ``num_spks``, ``lr_scheduler`` and an optional ``clip_grad_norm``
    (default 5). Modules: ``enc``, ``masker`` and ``dec``.
    """

    def compute_forward(self, mix, targets, stage, noise=None):
        """Separate ``mix`` into ``num_spks`` stems.

        Arguments:
            mix -- model batch; dim 1 is time (assumed (B, T)).
            targets -- list of ``num_spks`` reference tensors shaped like ``mix``.
            stage, noise -- unused.
        Returns:
            (est_source, targets), both (B, T, S): the estimates padded/trimmed to the
            mixture length and the references stacked on a trailing stem axis.
        """
        mix = mix.to(self.device)

        # Stack the references on a trailing axis -> (B, T, S).
        targets = torch.cat(
            [targets[i].unsqueeze(-1) for i in range(self.hparams.num_spks)],
            dim=-1,
        ).to(self.device)

        # One mask per stem, applied to the shared mixture encoding.
        mix_w = self.modules.enc(mix)
        est_mask = self.modules.masker(mix_w)
        mix_w = torch.stack([mix_w] * self.hparams.num_spks)
        sep_h = mix_w * est_mask

        # Decode each masked encoding and stack the stems on a trailing axis.
        est_source = torch.cat(
            [self.modules.dec(sep_h[i]).unsqueeze(-1) for i in range(self.hparams.num_spks)],
            dim=-1,
        )

        # The encoder/decoder round trip can change T; pad or trim back to the input length.
        T_origin = mix.size(1)
        T_est = est_source.size(1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin, :]

        return est_source, targets

    def compute_objectives(self, predictions, targets):
        """Permutation-invariant SI-SNR loss from SpeechBrain."""
        return sb.nnet.losses.get_si_snr_with_pitwrapper(targets, predictions)

    def _unpack_batch(self, batch):
        """Return ``(mix, [vocals, bass, drums, other])`` on ``self.device``.

        The batch is ``(mix, {stem_name: audio})``. ``[0]`` drops the DataLoader's batch
        axis (batch_size=1), so the dataset's own leading axis (presumably the audio
        channels) becomes the model's batch axis.
        """
        mix, sources = batch[0], batch[1]
        mix = mix[0].to(self.device)
        targets = [sources[name][0].to(self.device) for name in ("vocals", "bass", "drums", "other")]
        return mix, targets

    def fit_batch(self, batch):
        """Run one optimisation step and return the (detached, CPU) loss."""
        mix, targets = self._unpack_batch(batch)

        predictions, targets = self.compute_forward(mix, targets, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, targets)

        # Keep only per-example losses above the threshold; if none survive, keep all.
        th = -30
        loss_to_keep = loss[loss > th]
        if loss_to_keep.nelement() > 0:
            loss = loss_to_keep.mean()

        # NaN/inf losses compare False and fall through to the skip branch.
        if loss < 999999 and loss.nelement() > 0:
            loss.backward()
            clip = getattr(self.hparams, "clip_grad_norm", 5)
            if clip >= 0:
                torch.nn.utils.clip_grad_norm_(self.modules.parameters(), clip)
            self.optimizer.step()
        else:
            # Stored on the instance: a bare local counter raised UnboundLocalError here.
            self.nonfinite_count = getattr(self, "nonfinite_count", 0) + 1
            print(
                "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
                    self.nonfinite_count
                )
            )
            # Report 0 for the skipped batch.
            loss = torch.tensor(0.0, device=self.device)

        self.optimizer.zero_grad()

        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Compute the loss for a validation/test batch without tracking gradients."""
        mix, targets = self._unpack_batch(batch)

        with torch.no_grad():
            predictions, targets = self.compute_forward(mix, targets, stage)
            loss = self.compute_objectives(predictions, targets)

        return loss.detach()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Store train stats; after validation, anneal the learning rate if configured."""
        stage_stats = {"si-snr": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats

        # ``schedulers`` is not imported in this notebook, so use the full sb path.
        if stage == sb.Stage.VALID and isinstance(
            self.hparams.lr_scheduler, sb.nnet.schedulers.ReduceLROnPlateau
        ):
            _, next_lr = self.hparams.lr_scheduler([self.optimizer], epoch, stage_loss)
            sb.nnet.schedulers.update_learning_rate(self.optimizer, next_lr)

    def reset_layer_recursively(self, layer):
        """Re-initialise ``layer`` and every sub-module beneath it.

        Walks ``children()`` so each module is visited once; recursing over ``modules()``
        (which already yields all descendants) re-initialised deep layers many times.
        """
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
        for child_layer in layer.children():
            self.reset_layer_recursively(child_layer)

In [None]:
# Four-stem dual-path model for the WAV pipeline; hparams carry no auxiliary loss weights.
def _make_sb_transformer():
    """Build one 8-layer, 256-dim SpeechBrain transformer block for the dual-path masker."""
    return sb.lobes.models.dual_path.SBTransformerBlock(
        num_layers=8,
        d_model=256,
        nhead=8,
        d_ffn=1024,
        dropout=0,
        use_positional_encoding=True,
        norm_before=True,
    )


# Intra- and inter-chunk transformers: identical settings, separate parameters.
tf_blocks = {'SBtfintra': _make_sb_transformer(), 'SBtfinter': _make_sb_transformer()}

modules = dict(
    enc=sb.lobes.models.dual_path.Encoder(kernel_size=16, out_channels=256),
    masker=sb.lobes.models.dual_path.Dual_Path_Model(
        num_spks=4,
        in_channels=256,
        out_channels=256,
        num_layers=2,
        K=250,
        intra_model=tf_blocks['SBtfintra'],
        inter_model=tf_blocks['SBtfinter'],
        norm='ln',
        linear_layer_after_inter_intra=False,
        skip_around_intra=True,
    ),
    dec=sb.lobes.models.dual_path.Decoder(
        in_channels=256, out_channels=1, kernel_size=16, stride=8, bias=False
    ),
)

hparams = dict(
    num_spks=4,
    lr_scheduler=sb.nnet.schedulers.ReduceLROnPlateau(factor=0.5, patience=2, dont_halve_until_epoch=85),
    optimizer=lambda params: torch.optim.Adam(params, lr=0.00015),
)

brain = Separation(modules, hparams=hparams, opt_class=hparams['optimizer'], run_opts={"device": "cuda:0"})

In [None]:
# Re-initialise the parameters of every top-level module before training.
for _module_name, _module in brain.modules.items():
    brain.reset_layer_recursively(_module)

In [None]:
# Train for 100 epochs on the WAV training loader (no validation set is passed).
brain.fit(epoch_counter=range(100), train_set=train_loader_audio)

  0%|          | 0/86 [00:01<?, ?it/s]


RuntimeError: ignored

In [None]:
# Take one test example and unpack it into the mixture and the four reference stems.
contoh_1 = next(iter(test_loader_audio))
if len(contoh_1) == 2:
    # asteroid MUSDB18Dataset layout: (mix, {stem: audio}). Drop the loader's batch axis
    # so the layout matches what Separation.fit_batch feeds the model.
    _mix_batch, _sources = contoh_1
    mix = _mix_batch[0]
    vocals, bass, drums, other = (_sources[stem][0] for stem in ("vocals", "bass", "drums", "other"))
else:
    # source_separation_dataset layout: (mix, vocals, bass, drums, other, mask).
    mix, vocals, bass, drums, other = contoh_1[:5]
targets = [vocals, bass, drums, other]

In [None]:
# Inference only: put the modules in eval mode and skip building the autograd graph.
brain.modules.eval()
with torch.no_grad():
    predictions, new_targets = brain.compute_forward(mix, targets, sb.Stage.TEST)

In [None]:
from IPython.display import Audio

# Play the mixture at the test dataset's sample rate when it records one (asteroid's
# MUSDB18Dataset was built with sample_rate=16000); otherwise fall back to 8 kHz.
Audio(mix.squeeze().detach(), rate=getattr(test_loader_audio.dataset, "sample_rate", 8000))

In [None]:
# Reference "other" stem (index 3 of vocals, bass, drums, other) at the test dataset's
# sample rate when available, else 8 kHz.
Audio(new_targets[:, :, 3].cpu().squeeze().detach(), rate=getattr(test_loader_audio.dataset, "sample_rate", 8000))

In [None]:
# Estimated "other" stem (index 3 of vocals, bass, drums, other) at the test dataset's
# sample rate when available, else 8 kHz.
Audio(predictions[:, :, 3].cpu().squeeze().detach(), rate=getattr(test_loader_audio.dataset, "sample_rate", 8000))