<a href="https://colab.research.google.com/github/nschmidtg/thesis/blob/main/Test1_as_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# In Google Colab: be sure to select a GPU runtime (Runtime → Change runtime type → Hardware accelerator).


In [1]:
# First off, install asteroid
!pip install git+https://github.com/asteroid-team/asteroid --quiet

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
  Building wheel for asteroid (PEP 517) ... [?25l[?25hdone


## After installing the requirements, restart the runtime (Runtime → Restart runtime, or Ctrl+M .).
Otherwise importing asteroid will fail.

In [2]:
!pip install pytorch-lightning --quiet

In [3]:
# Asteroid is based on PyTorch and PyTorch-Lightning.
from torch import optim
from pytorch_lightning import Trainer

In [4]:
# DPRNN-TasNet is the separation architecture trained below.
from asteroid import DPRNNTasNet

In [5]:
# In this example we use Permutation Invariant Training (PIT) and the SI-SDR loss.
from asteroid.losses import pairwise_neg_sisdr, PITLossWrapper

In [6]:
# install musdb:
!pip install musdb --quiet

[?25l[K     |▋                               | 10kB 28.9MB/s eta 0:00:01[K     |█▎                              | 20kB 34.4MB/s eta 0:00:01[K     |██                              | 30kB 37.1MB/s eta 0:00:01[K     |██▋                             | 40kB 33.2MB/s eta 0:00:01[K     |███▏                            | 51kB 35.3MB/s eta 0:00:01[K     |███▉                            | 61kB 27.4MB/s eta 0:00:01[K     |████▌                           | 71kB 28.9MB/s eta 0:00:01[K     |█████▏                          | 81kB 25.8MB/s eta 0:00:01[K     |█████▉                          | 92kB 27.4MB/s eta 0:00:01[K     |██████▍                         | 102kB 25.2MB/s eta 0:00:01[K     |███████                         | 112kB 25.2MB/s eta 0:00:01[K     |███████▊                        | 122kB 25.2MB/s eta 0:00:01[K     |████████▍                       | 133kB 25.2MB/s eta 0:00:01[K     |█████████                       | 143kB 25.2MB/s eta 0:00:01[K     |█████████▋   

In [7]:
# install ffmpeg (stems are mp4 by default)
!sudo apt-get install ffmpeg

Reading package lists... Done
Building dependency tree       
Reading state information... Done
ffmpeg is already the newest version (7:3.4.8-0ubuntu0.2).
0 upgraded, 0 newly installed, 0 to remove and 15 not upgraded.


In [6]:
# MiniLibriMix is a tiny version of LibriMix (https://github.com/JorisCos/LibriMix),
# which is a free speech separation dataset.
from asteroid.data import LibriMix

# import musdb to create the mixtures: https://github.com/sigsep/sigsep-mus-db
import musdb
# Asteroid's System is a convenience wrapper for PyTorch-Lightning.
from asteroid.engine import System

from IPython.display import display, Audio


In [9]:
# Download the MUSDB18 7-second sample dataset (used when no root directory is given).
mus = musdb.DB(download=True)

# To use the full dataset instead, set a dataset root directory:
# mus = musdb.DB(root="/path/to/musdb")

# To work directly with wav: https://github.com/sigsep/sigsep-mus-db#using-wav-files-optional

Downloading MUSDB 7s Sample Dataset to /root/MUSDB18/MUSDB18-7...
Done!


In [10]:
# This will automatically download MiniLibriMix from Zenodo on the first run.
# Only the download is needed here (its speech files are mixed into the augmented
# dataset below). The LibriMix loaders get their own names so that `train_loader` /
# `val_loader` can only ever refer to the PodcastMix loaders built later.
librimix_train_loader, librimix_val_loader = LibriMix.loaders_from_mini(task="sep_clean", batch_size=8)

HBox(children=(FloatProgress(value=0.0, max=640547371.0), HTML(value='')))


Drop 0 utterances from 800 (shorter than 3 seconds)
Drop 0 utterances from 200 (shorter than 3 seconds)


# Create the augmented dataset

Using the LibriMix and MUSDB18 datasets, an augmented podcast/radio-show-like dataset is created.

In [13]:
import librosa, os

In [12]:
def create_folder_structure(path, subfolders=('linear_mono', 'linear_stereo',
                                              'sidechain_mono', 'sidechain_stereo',
                                              'track_mono', 'track_stereo',
                                              'speech_mono')):
    """Create ``path`` and the standard sub-folders of one dataset split.

    Safe to call repeatedly: existing folders are left untouched.

    Args:
        path (str): split directory, e.g. ``"augmented_dataset/train"``.
        subfolders (iterable of str): names of the sub-directories to create
            inside ``path``. Defaults to one folder per mix variant plus the
            isolated track/speech folders.
    """
    # Create the split directory itself even if `subfolders` is empty.
    os.makedirs(path, exist_ok=True)
    for name in subfolders:
        # exist_ok avoids the check-then-create race and makes re-runs a no-op
        os.makedirs(os.path.join(path, name), exist_ok=True)

In [13]:
# Create the folder tree of the augmented dataset (safe to re-run).
train_path = "augmented_dataset/train"
val_path = "augmented_dataset/val"
for split_path in (train_path, val_path):
    create_folder_structure(split_path)

# One metadata folder per split; makedirs also creates the parent 'metadata' folder.
for split in ("train", "val"):
    os.makedirs(os.path.join("augmented_dataset", "metadata", split), exist_ok=True)

In [10]:
from os import listdir
from os.path import isfile, join
import random
import numpy as np
import re
import csv

In [15]:
def mix_audio_sources(track_path, speech_path, output_path, music_to_speech_ratio = 0.2):
    """Mix one music track with one speech utterance and write the results.

    Both files are loaded at 44.1 kHz and cropped to the shorter of the two.
    The following files are written under ``output_path``, all with the same
    file name and each peak-normalized by ``librosa.output.write_wav(norm=True)``:

    * ``linear_mono/``   -- mono music * ratio + speech
    * ``linear_stereo/`` -- stereo music * ratio + speech (speech on both channels)
    * ``speech_mono/``   -- the cropped speech
    * ``track_mono/``    -- the cropped music, downmixed to mono
    * ``track_stereo/``  -- the cropped stereo music

    The sidechain mixes are not produced (only their folders exist).

    Args:
        track_path (str): path to the music file. Stereo is expected (musdb);
            a mono file is duplicated onto two channels.
        speech_path (str): path to the (mono) speech file.
        output_path (str): split directory prepared by ``create_folder_structure``.
        music_to_speech_ratio (float): gain applied to the music before adding
            the speech.

    Returns:
        tuple: ``(file_name, length)`` -- the file name shared by all written
        files and their length in samples.
    """
    sample_rate = 44100
    track, _ = librosa.load(track_path, sr=sample_rate, mono=False)
    speech, _ = librosa.load(speech_path, sr=sample_rate)
    if track.ndim == 1:
        # Mono music file: duplicate it so the stereo outputs still have 2 channels.
        track = np.stack([track, track])

    # Crop both signals to the shorter one.
    length = min(track.shape[1], len(speech))
    track_stereo = track[:2, :length]
    # Average (not sum) the channels, so the mono mix has the same music level,
    # relative to the speech, as each channel of the stereo mix.
    track_mono = track_stereo.mean(axis=0)
    speech = speech[:length]

    linear_stereo = track_stereo * music_to_speech_ratio + speech
    linear_mono = track_mono * music_to_speech_ratio + speech

    track_name = re.sub("[^0-9a-zA-Z]+", "-", os.path.basename(track_path))
    file_name = track_name + '_' + os.path.basename(speech_path)

    outputs = (
        ("linear_mono", linear_mono),
        ("linear_stereo", linear_stereo),
        ("speech_mono", speech),
        ("track_mono", track_mono),
        ("track_stereo", track_stereo),
    )
    for folder, audio in outputs:
        # NOTE: librosa.output was removed in librosa 0.8; this needs librosa < 0.8.
        librosa.output.write_wav(os.path.join(output_path, folder, file_name),
                                 audio, sample_rate, norm=True)

    return file_name, length

In [16]:
# Speech sources for the augmented dataset. The train speech comes from the
# MiniLibriMix *train* split: s1 and s2 of one split draw on the same speaker pool,
# so taking both train and val speech from MiniLibriMix/val leaks val speakers into
# training. (LibriMix splits are presumably speaker-disjoint, following LibriSpeech.)
speech_path_train = "MiniLibriMix/train/s1/"
speech_path_val = "MiniLibriMix/val/s2/"

# os.listdir order is arbitrary; sorting makes the seeded random picks below
# reproducible across machines.
speech_array_train = sorted(f for f in listdir(speech_path_train) if isfile(join(speech_path_train, f)))
speech_array_val = sorted(f for f in listdir(speech_path_val) if isfile(join(speech_path_val, f)))

In [17]:
# Seed the speech-file picks made by the dataset-creation loop.
random.seed(1)

# Start every metadata CSV with only the header row (truncating older runs);
# the dataset-creation loop appends one row per generated mixture.
CSV_HEADER = ["", "mixture_ID", "mixture_path", "track_path", "speech_path", "length"]
for split in ("train", "val"):
    for mix_name in ("linear_stereo", "linear_mono"):
        csv_path = f'augmented_dataset/metadata/{split}/{mix_name}.csv'
        with open(csv_path, 'w', newline='') as file:
            csv.writer(file).writerow(CSV_HEADER)

In [18]:
# Create the train/val mixtures: the first `n_train` musdb tracks go to the train
# split and all remaining tracks to the val split. Each track is paired with one
# randomly picked speech file and mixed by `mix_audio_sources`.
n_train = 100

# (metadata csv / mixture folder, matching music folder). The speech is only written
# as mono (speech_mono/; no speech_stereo folder exists), so both csv files point
# their speech_path at the mono speech file.
metadata_layout = (("linear_mono", "track_mono"), ("linear_stereo", "track_stereo"))

for i, track in enumerate(mus):
    if i < n_train:
        speech_candidates, path, speech_path = speech_array_train, train_path, speech_path_train
        csv_path = 'augmented_dataset/metadata/train'
    else:
        speech_candidates, path, speech_path = speech_array_val, val_path, speech_path_val
        csv_path = 'augmented_dataset/metadata/val'

    speech_name = random.choice(speech_candidates)
    speech_file = speech_path + speech_name

    file_name, min_length = mix_audio_sources(track.path, speech_file, path, music_to_speech_ratio=0.1)

    for mix_name, track_folder in metadata_layout:
        with open(csv_path + '/' + mix_name + '.csv', 'a', newline='') as file:
            csv.writer(file).writerow([
                i,
                file_name,
                path + "/" + mix_name + "/" + file_name,
                path + "/" + track_folder + "/" + file_name,
                path + "/speech_mono/" + file_name,
                min_length,
            ])

## Create the DataLoader object

In [7]:
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import soundfile as sf
import torch

In [14]:
class PodcastMix(Dataset):
    """Dataset class for PodcastMix source separation tasks.

    Each item is a ``(mixture, sources)`` pair of float32 tensors, where
    ``sources`` stacks, in this order, the music track and the speech read from
    the ``track_path`` and ``speech_path`` columns of the metadata csv.

    Args:
        csv_dir (str): The path to the directory holding the metadata csv files.
        task (str): One of ``'linear_mono'``, ``'linear_stereo'``,
            ``'sidechain_mono'`` or ``'sidechain_stereo'``; the first csv file
            (in sorted order) in ``csv_dir`` whose name contains ``task`` is used:
            * ``'linear_mono'`` for linear_mono mix
            * ``'linear_stereo'`` for linear_stereo mix
        sample_rate (int) : The sample rate of the sources and mixtures.
        n_src (int) : The number of sources in the mixture. Only stored: every
            item always holds the two sources (track, speech).
        segment (int) : The desired sources and mixtures length in s. Files
            shorter than that are dropped and a random excerpt of that length
            is returned; ``None`` returns whole files.
    References
        [1] "LibriMix: An Open-Source Dataset for Generalizable Speech Separation",
        Cosentino et al. 2020.
        [2] "MUSDB18 - a corpus for music separation",
        Zafar et al. 2018.
    """

    dataset_name = "PodcastMix"

    def __init__(self, csv_dir, task="linear_mono", sample_rate=44100, n_src=2, segment=3):
        self.csv_dir = csv_dir
        self.task = task
        # Get the csv corresponding to the task (sorted, so the pick does not
        # depend on os.listdir order)
        md_file = sorted(f for f in os.listdir(csv_dir) if task in f)[0]
        self.csv_path = os.path.join(self.csv_dir, md_file)
        self.segment = segment
        self.sample_rate = sample_rate
        # Open csv file
        self.df = pd.read_csv(self.csv_path, engine='python')
        # Get rid of the utterances too short
        if self.segment is not None:
            max_len = len(self.df)
            self.seg_len = int(self.segment * self.sample_rate)
            # Ignore the file shorter than the desired_length
            self.df = self.df[self.df["length"] >= self.seg_len]
            print(
                f"Drop {max_len - len(self.df)} utterances from {max_len} "
                f"(shorter than {segment} seconds)"
            )
        else:
            self.seg_len = None
        self.n_src = n_src

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Get the row in dataframe
        row = self.df.iloc[idx]
        # Stored on the instance (as asteroid's LibriMix does), presumably so callers
        # can read which file the last returned item came from.
        self.mixture_path = row["mixture_path"]
        # If a segment length is set, pick a random excerpt of that length
        if self.seg_len is not None:
            start = random.randint(0, row["length"] - self.seg_len)
            stop = start + self.seg_len
        else:
            start = 0
            stop = None
        # Read the sources: music track first, then speech
        sources_list = []
        for column in ("track_path", "speech_path"):
            s, _ = sf.read(row[column], dtype="float32", start=start, stop=stop)
            sources_list.append(s)
        # Read the mixture and convert it to a torch tensor
        mixture, _ = sf.read(self.mixture_path, dtype="float32", start=start, stop=stop)
        mixture = torch.from_numpy(mixture)
        # NOTE(review): np.vstack of 1-D (mono) sources gives (n_src, time). A stereo
        # track is read as (time, 2) and cannot be stacked with the mono speech, so
        # the stereo tasks are not supported by this method yet.
        sources = torch.from_numpy(np.vstack(sources_list))
        return mixture, sources

    @classmethod
    def loaders_from_mini(cls, batch_size=4, shuffle_train=True, **kwargs):
        """Build train and validation DataLoaders over the generated metadata.
        Args:
            batch_size (int): Batch size of both DataLoaders.
                To have more control on the DataLoaders, call
                `mini_from_download` and instantiate the DataLoader yourself.
            shuffle_train (bool): Whether the training loader reshuffles the
                training set every epoch. The validation loader is never shuffled.
            **kwargs: keyword arguments to pass to `PodcastMix`, see `__init__`.
                The kwargs will be fed to both the training set and validation
                set.
        Returns:
            train_loader, val_loader: training and validation DataLoader out of
            `PodcastMix` Dataset.
        Examples
            >>> train_loader, val_loader = PodcastMix.loaders_from_mini(
            >>>     task='linear_mono', batch_size=4
            >>> )
        """
        train_set, val_set = cls.mini_from_download(**kwargs)
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=shuffle_train, drop_last=True)
        val_loader = DataLoader(val_set, batch_size=batch_size, drop_last=True)
        return train_loader, val_loader

    @classmethod
    def mini_from_download(cls, **kwargs):
        """Return the train and validation Datasets of the generated dataset.
        Nothing is downloaded: the metadata csv files are read from
        ``augmented_dataset/metadata/train`` and ``augmented_dataset/metadata/val``.
        Args:
            **kwargs: keyword arguments to pass to `PodcastMix`, see `__init__`.
                The kwargs will be fed to both the training set and validation
                set. ``sample_rate`` may be given but must be 44100.
        Returns:
            train_set, val_set: training and validation instances of
            `PodcastMix` (data.Dataset).
        Examples
            >>> train_set, val_set = PodcastMix.mini_from_download(task='linear_mono')
        """
        # kwargs checks
        assert "csv_dir" not in kwargs, "Cannot specify csv_dir when downloading."
        # Pop (not get) so sample_rate is not passed twice to __init__ below, which
        # raised "got multiple values for keyword argument 'sample_rate'".
        sample_rate = kwargs.pop("sample_rate", 44100)
        assert sample_rate == 44100, "Only 44100 Hz sample rate is supported in PodcastMix."
        meta_path = 'augmented_dataset/metadata'
        # Create dataset instances
        train_set = cls(os.path.join(meta_path, "train"), sample_rate=sample_rate, **kwargs)
        val_set = cls(os.path.join(meta_path, "val"), sample_rate=sample_rate, **kwargs)
        return train_set, val_set

In [15]:
# DataLoaders over the generated mono mixtures, two examples per batch.
train_loader, val_loader = PodcastMix.loaders_from_mini(
    batch_size=2,
    task="linear_mono",
)

Drop 0 utterances from 100 (shorter than 3 seconds)
Drop 0 utterances from 44 (shorter than 3 seconds)


# Train the network

In [16]:
# Tell DPRNN that we want to separate 2 sources (music and speech).
# The augmented dataset is written at 44.1 kHz; asteroid models otherwise default to
# sample_rate=8000, which makes `model.separate(..., resample=True)` downsample the
# input to 8 kHz, i.e. run the model on audio unlike the data it was trained on.
model = DPRNNTasNet(n_src=2, sample_rate=44100)

In [17]:
# Permutation Invariant Training: the pairwise negative SI-SDR between estimates and
# targets is computed as a matrix and the loss is taken over the best source permutation.
loss = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Bundle model, optimizer, loss and data loaders into asteroid's Lightning module.
system = System(
    model,
    optimizer,
    loss,
    train_loader,
    val_loader,
)

In [32]:
# Train for MAX_EPOCHS epochs, on a single GPU when one is available (CPU otherwise).
# If you're running this on Google Colab, be sure to select a GPU runtime
# (Runtime → Change runtime type → Hardware accelerator).
MAX_EPOCHS = 50
trainer = Trainer(max_epochs=MAX_EPOCHS, gpus=1 if torch.cuda.is_available() else 0)
trainer.fit(system)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type           | Params
---------------------------------------------
0 | model     | DPRNNTasNet    | 3.7 M 
1 | loss_func | PITLossWrapper | 0     
---------------------------------------------
3.7 M     Trainable params
0         Non-trainable params
3.7 M     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [19]:
# !pip install librosa --quiet

In [33]:
import librosa

# Separate one validation mixture; the estimates are written next to it as
# <name>_est1.wav and <name>_est2.wav. The file name depends on the random speech
# pick made during dataset creation, so update it if the dataset is regenerated.
# The path is relative to the working directory (/content on Colab), like the other
# dataset paths in this notebook.
example_mixture = "augmented_dataset/val/linear_mono/Ben-Carrigan-We-ll-Talk-About-It-All-Tonight-stem-mp4_1993-147964-0004_6345-93302-0016.wav"
model.separate(example_mixture, resample=True)



In [34]:
from IPython.display import display, Audio

# Listen to the mixture and the two estimates written next to it by `model.separate`.
# Relative path: resolved against the working directory (/content on Colab).
example_stem = "augmented_dataset/val/linear_mono/Ben-Carrigan-We-ll-Talk-About-It-All-Tonight-stem-mp4_1993-147964-0004_6345-93302-0016"
for suffix in ("", "_est1", "_est2"):
    display(Audio(example_stem + suffix + ".wav"))

# Try to use ConvTasNet


In [None]:
from asteroid import ConvTasNet

In [None]:
# Tell ConvTasNet that we want to separate 2 sources (music and speech), at the
# 44.1 kHz rate of the augmented dataset (asteroid models default to 8 kHz).
model = ConvTasNet(n_src=2, sample_rate=44100)

In [None]:
# Same training setup as for DPRNN: PIT over the pairwise negative SI-SDR matrix.
loss = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Wrap everything in asteroid's Lightning module.
system = System(
    model,
    optimizer,
    loss,
    train_loader,
    val_loader,
)

In [None]:
# Train for 1 epoch, on a single GPU when one is available (CPU otherwise).
# If you're running this on Google Colab, be sure to select a GPU runtime
# (Runtime → Change runtime type → Hardware accelerator).
trainer = Trainer(max_epochs=1, gpus=1 if torch.cuda.is_available() else 0)
trainer.fit(system)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type           | Params
---------------------------------------------
0 | model     | ConvTasNet     | 5 M   
1 | loss_func | PITLossWrapper | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..





1