In [3]:
from google.colab import drive
drive.mount("/content/gdrive")
% cd gdrive/MyDrive/'Colab Notebooks'/FullSubNet


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
[Errno 2] No such file or directory: 'gdrive/MyDrive/Colab Notebooks/FullSubNet'
/content/gdrive/MyDrive/Colab Notebooks/FullSubNet


In [4]:
import matplotlib.pyplot as plt
import torch
from torch.cuda.amp import autocast
from tqdm import tqdm


In [None]:

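# The audio_zen helpers below are defined inline in later cells of this notebook,
# so the package imports are intentionally left disabled:
# from audio_zen.trainer.base_trainer import BaseTrainer
# from audio_zen.acoustics.mask import build_complex_ideal_ratio_mask, decompress_cIRM
# from audio_zen.acoustics.feature import mag_phase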

In [9]:
# Select matplotlib's non-interactive "agg" backend. The only figures built in the
# cells below are handed to the TensorBoard writer via `add_figure`.
plt.switch_backend('agg')

In [10]:
def build_complex_ideal_ratio_mask(noisy: torch.Tensor, clean: torch.Tensor, eps: float = None) -> torch.Tensor:
    """
    Build the compressed complex ideal ratio mask (cIRM), i.e. the complex ratio clean / noisy,
    with its real and imaginary parts stacked on the last axis and squashed by `compress_cIRM`.

    Args:
        noisy: [B, F, T], noisy complex-valued stft coefficients
        clean: [B, F, T], clean complex-valued stft coefficients
        eps: small positive constant added to |noisy|^2 so that the division is defined where the
            noisy spectrum is zero. Defaults to the float32 machine epsilon.
            NOTE(review): the previous code read a module-level `EPSILON` that is not defined in the
            visible notebook cells; float32 machine epsilon is assumed to be the intended value — confirm.

    Returns:
        [B, F, T, 2], the mask compressed with `compress_cIRM(K=10, C=0.1)`;
        index 0 of the last axis is the real part, index 1 the imaginary part.
    """
    if eps is None:
        eps = torch.finfo(torch.float32).eps

    noisy_real, noisy_imag = noisy.real, noisy.imag
    clean_real, clean_imag = clean.real, clean.imag

    # clean / noisy == clean * conj(noisy) / |noisy|^2
    denominator = torch.square(noisy_real) + torch.square(noisy_imag) + eps

    mask_real = (noisy_real * clean_real + noisy_imag * clean_imag) / denominator
    mask_imag = (noisy_real * clean_imag - noisy_imag * clean_real) / denominator

    complex_ratio_mask = torch.stack((mask_real, mask_imag), dim=-1)

    return compress_cIRM(complex_ratio_mask, K=10, C=0.1)


def compress_cIRM(mask, K=10, C=0.1):
    """
    Compress from (-inf, +inf) to [-K ~ K]

    Computes K * (1 - exp(-C * mask)) / (1 + exp(-C * mask)) after flooring the input at -100
    (the floor keeps exp(-C * mask) bounded for very negative inputs).

    Args:
        mask: torch.Tensor, or a numpy array / scalar, holding the uncompressed mask values.
        K: output bound; results lie in [-K, K].
        C: steepness of the compression curve.

    Returns:
        The compressed mask, same container kind (torch / numpy) as the input.
    """
    # The floor used to be written as `-100 * (mask <= -100) + mask * (mask > -100)`, which yields
    # NaN for -inf inputs (-inf * 0); clamping gives the intended -100 instead.
    if torch.is_tensor(mask):
        mask = torch.clamp(mask, min=-100)
        exp_term = torch.exp(-C * mask)
    else:
        # Imported locally so this branch does not depend on a notebook-level numpy import.
        import numpy as np

        mask = np.maximum(mask, -100)
        exp_term = np.exp(-C * mask)
    return K * (1 - exp_term) / (1 + exp_term)


def decompress_cIRM(mask, K=10, limit=9.9, C=0.1):
    """
    Invert `compress_cIRM`: map a compressed mask in (-K, K) back to the uncompressed range.

    Solving y = K * (1 - exp(-C * x)) / (1 + exp(-C * x)) for x gives
    x = -(1 / C) * log((K - y) / (K + y)).

    Args:
        mask: torch.Tensor of compressed mask values.
        K: bound that was used for compression.
        limit: values are clamped to [-limit, limit] first; it must be smaller than K so that
            the argument of the logarithm stays finite and positive.
        C: steepness that was used for compression.

    Returns:
        torch.Tensor with the decompressed mask.
    """
    # Clamp instead of boolean-mask arithmetic: `mask * (|mask| < limit)` turns +/-inf into NaN.
    mask = torch.clamp(mask, min=-limit, max=limit)
    # The scale factor is 1 / C. The earlier `-K * log(...)` only matched because 1 / 0.1 == 10 == K
    # for the default arguments; any other K produced a wrongly scaled result.
    return -(1.0 / C) * torch.log((K - mask) / (K + mask))


In [11]:
def mag_phase(complex_tensor):
    """
    Split a tensor into its element-wise magnitude and phase angle.

    Args:
        complex_tensor: torch.Tensor (typically complex-valued STFT coefficients).

    Returns:
        Tuple `(magnitude, phase)`; the phase is in radians.
    """
    magnitude = complex_tensor.abs()
    phase = complex_tensor.angle()
    return magnitude, phase

### base_trainer

In [12]:
#base_trainer.ipynb
class BaseTrainer:
    """
    Shared training scaffolding: wraps the model in DistributedDataParallel, sets up AMP gradient
    scaling and STFT helpers, handles checkpoint preload / resume / save, logs to TensorBoard and
    drives the epoch loop. Subclasses must implement `_train_epoch` and `_validation_epoch`.

    NOTE(review): this class relies on names that are not imported in the visible notebook cells
    (np, partial, Path, time, toml, librosa, GradScaler, DistributedDataParallel, SummaryWriter,
    Parallel, delayed, stft, istft, metrics, prepare_empty_dir, transform_pesq_range, ExecutionTime).
    Confirm they are brought into scope by another cell before running on a fresh kernel.
    """

    def __init__(self, dist, rank, config, resume, only_validation, model, loss_function, optimizer):
        """
        Args:
            dist: distributed handle; only `dist.barrier()` is used here (when resuming).
            rank: index of the current process, also used as the CUDA device index.
            config (dict): configuration with "meta", "acoustics" and "trainer" sections.
            resume (bool): if True, restore state from "<checkpoints_dir>/latest_model.tar".
            only_validation (bool): if True, rank 0 only runs validation in `train`.
            model: network to train; moved to `rank` and wrapped in DistributedDataParallel.
            loss_function: callable stored for use by subclasses.
            optimizer: optimizer stored for use by subclasses and checkpointing.
        """
        torch.cuda.set_device(rank)
        self.model = DistributedDataParallel(model.cuda(rank), device_ids=[rank])
        self.optimizer = optimizer
        self.loss_function = loss_function

        # DistributedDataParallel (DDP)
        self.rank = rank
        self.dist = dist

        #It may be due to K80 or cuda problems.
        torch.backends.cudnn.enabled = config["meta"]["cudnn_enable"]
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False

        # Automatic mixed precision (AMP)
        self.use_amp = config["meta"]["use_amp"]
        self.scaler = GradScaler(enabled=self.use_amp)

        # Acoustics
        self.acoustic_config = config["acoustics"]
        n_fft = self.acoustic_config["n_fft"]
        hop_length = self.acoustic_config["hop_length"]
        win_length = self.acoustic_config["win_length"]

        # Supported STFT
        self.torch_stft = partial(stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        self.torch_istft = partial(istft, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        self.librosa_stft = partial(librosa.stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        self.librosa_istft = partial(librosa.istft, hop_length=hop_length, win_length=win_length)

        # Trainer.train in the config
        self.train_config = config["trainer"]["train"]
        self.epochs = self.train_config["epochs"]
        self.save_checkpoint_interval = self.train_config["save_checkpoint_interval"]
        self.clip_grad_norm_value = self.train_config["clip_grad_norm_value"]
        assert self.save_checkpoint_interval >= 1, "Check the 'save_checkpoint_interval' parameter in the config. It should be large than one."

        # Trainer.validation in the config
        self.validation_config = config["trainer"]["validation"]
        self.validation_interval = self.validation_config["validation_interval"]
        self.save_max_metric_score = self.validation_config["save_max_metric_score"]
        assert self.validation_interval >= 1, "Check the 'validation_interval' parameter in the config. It should be large than one."

        # Trainer.visualization in the config
        self.visualization_config = config["trainer"]["visualization"]

        # In the 'train.py' file, if the 'resume' item is 'True', we will update the following args:
        self.start_epoch = 1
        # Start from the worst possible score for the chosen direction (higher-is-better or lower-is-better).
        self.best_score = -np.inf if self.save_max_metric_score else np.inf
        self.save_dir = Path(config["meta"]["save_dir"]).expanduser().absolute() / config["meta"]["experiment_name"]
        self.checkpoints_dir = self.save_dir / "checkpoints"
        self.logs_dir = self.save_dir / "logs"

        if resume:
            self._resume_checkpoint()

        # Debug validation, which skips training
        self.only_validation = only_validation

        if config["meta"]["preloaded_model_path"]:
            self._preload_model(Path(config["meta"]["preloaded_model_path"]))

        # Only rank 0 creates directories, the TensorBoard writer and the config dump.
        if self.rank == 0:
            prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume=resume)

            self.writer = SummaryWriter(self.logs_dir.as_posix(), max_queue=5, flush_secs=30)
            self.writer.add_text(
                tag="Configuration",
                text_string=f"<pre>  \n{toml.dumps(config)}  \n</pre>",
                global_step=1
            )

            print("The configurations are as follows: ")
            print(config)  # except "\n"

            with open((self.save_dir / f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(), "w") as handle:
                toml.dump(config, handle)

            self._print_networks([self.model])

    def _preload_model(self, model_path):
        """
        Preload model parameters (in "*.tar" format) at the start of experiment.

        The checkpoint must be a dict with a "model" entry holding a state dict. Loading is
        non-strict, so keys that do not match are ignored without raising.

        Args:
            model_path (Path): The file path of the *.tar file
        """
        model_path = model_path.expanduser().absolute()
        assert model_path.exists(), f"The file {model_path.as_posix()} is not exist. please check path."

        model_checkpoint = torch.load(model_path.as_posix(), map_location="cpu")
        model_state = model_checkpoint["model"]

        # `_save_checkpoint` stores the state dict of the *unwrapped* module, i.e. keys without the
        # "module." prefix. Loading such a dict into the DDP wrapper with strict=False matches no key
        # and silently loads nothing, so target the wrapped module unless the keys are DDP-prefixed.
        load_target = self.model
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            has_ddp_prefix = bool(model_state) and all(key.startswith("module.") for key in model_state)
            if not has_ddp_prefix:
                load_target = self.model.module

        load_target.load_state_dict(model_state, strict=False)
        self.model.to(self.rank)

        if self.rank == 0:
            print(f"Model preloaded successfully from {model_path.as_posix()}.")

    def _resume_checkpoint(self):
        """
        Resume the experiment from the latest checkpoint.

        Restores the start epoch, best score, optimizer state, grad-scaler state and model
        parameters from "<checkpoints_dir>/latest_model.tar".
        """
        latest_model_path = self.checkpoints_dir.expanduser().absolute() / "latest_model.tar"
        assert latest_model_path.exists(), f"{latest_model_path} does not exist, can not load latest checkpoint."

        # Load it on the CPU and later use .to(device) on the model
        # Maybe slightly slow than use map_location="cuda:<...>"
        # https://stackoverflow.com/questions/61642619/pytorch-distributed-data-parallel-confusion
        checkpoint = torch.load(latest_model_path.as_posix(), map_location="cpu")

        # Make sure all processes (GPUs) do not start loading before the saving is finished.
        # see https://stackoverflow.com/questions/59760328/how-does-torch-distributed-barrier-work
        self.dist.barrier()

        self.start_epoch = checkpoint["epoch"] + 1
        self.best_score = checkpoint["best_score"]
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scaler.load_state_dict(checkpoint["scaler"])

        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            self.model.module.load_state_dict(checkpoint["model"])
        else:
            self.model.load_state_dict(checkpoint["model"])

        # self.model.to(self.rank)

        if self.rank == 0:
            print(f"Model checkpoint loaded. Training will begin at {self.start_epoch} epoch.")

    def _save_checkpoint(self, epoch, is_best_epoch=False):
        """
        Save checkpoint to "<save_dir>/<config name>/checkpoints" directory, which consists of:
            - epoch
            - best metric score in historical epochs
            - optimizer parameters
            - grad scaler state
            - model parameters

        Args:
            epoch (int): epoch number stored in the checkpoint and used in the per-epoch file name.
            is_best_epoch (bool): In the current epoch, if the model get a best metric score (is_best_epoch=True),
                                the checkpoint of model will be saved as "<save_dir>/checkpoints/best_model.tar".
        """
        print(f"\t Saving {epoch} epoch model checkpoint...")

        state_dict = {
            "epoch": epoch,
            "best_score": self.best_score,
            "optimizer": self.optimizer.state_dict(),
            "scaler": self.scaler.state_dict()
        }

        # Store the parameters of the unwrapped module so that the keys carry no "module." prefix.
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            state_dict["model"] = self.model.module.state_dict()
        else:
            state_dict["model"] = self.model.state_dict()

        # Saved in "latest_model.tar"
        # Contains all checkpoint information, including the optimizer parameters, the model parameters, etc.
        # New checkpoint will overwrite the older one.
        torch.save(state_dict, (self.checkpoints_dir / "latest_model.tar").as_posix())

        # "model_{epoch_number}.pth"
        # Contains only model.
        torch.save(state_dict["model"], (self.checkpoints_dir / f"model_{str(epoch).zfill(4)}.pth").as_posix())

        # If the model get a best metric score (means "is_best_epoch=True") in the current epoch,
        # the model checkpoint will be saved as "best_model.tar"
        # The newer best-scored checkpoint will overwrite the older one.
        if is_best_epoch:
            print(f"\t :smiley: Found a best score in the {epoch} epoch, saving...")
            torch.save(state_dict, (self.checkpoints_dir / "best_model.tar").as_posix())

    def _is_best_epoch(self, score, save_max_metric_score=True):
        """
        Check if the current model got the best metric score

        Ties count as a new best. `self.best_score` is updated when the check succeeds.

        Args:
            score: metric value of the current epoch.
            save_max_metric_score (bool): True if larger scores are better, False if smaller are better.

        Returns:
            bool: True if `score` is at least as good as the best score seen so far.
        """
        if save_max_metric_score and score >= self.best_score:
            self.best_score = score
            return True
        elif not save_max_metric_score and score <= self.best_score:
            self.best_score = score
            return True
        else:
            return False

    @staticmethod
    def _print_networks(models: list):
        """
        Print the parameter count (in millions) of every model and of all models combined.

        Args:
            models (list): modules exposing `parameters()`.
        """
        print(f"This project contains {len(models)} models, the number of the parameters is: ")

        params_of_all_networks = 0
        for idx, model in enumerate(models, start=1):
            params_of_network = sum(param.numel() for param in model.parameters())

            print(f"\tNetwork {idx}: {params_of_network / 1e6} million.")
            params_of_all_networks += params_of_network

        print(f"The amount of parameters in the project is {params_of_all_networks / 1e6} million.")

    def _set_models_to_train_mode(self):
        """Put the model into training mode."""
        self.model.train()

    def _set_models_to_eval_mode(self):
        """Put the model into evaluation mode."""
        self.model.eval()

    def spec_audio_visualization(self, noisy, enhanced, clean, name, epoch, mark=""):
        """
        Log the noisy / enhanced / clean audio and a three-panel spectrogram figure to TensorBoard.

        Args:
            noisy, enhanced, clean: 1-D waveforms accepted by `librosa.stft` and `writer.add_audio`.
            name (str): sample name used in the TensorBoard tags.
            epoch (int): global step for the log entries.
            mark (str): tag prefix (the subclass passes the speech type).
        """
        self.writer.add_audio(f"{mark}_Speech/{name}_Noisy", noisy, epoch, sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Enhanced", enhanced, epoch, sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Clean", clean, epoch, sample_rate=16000)

        # Visualize the spectrogram of noisy speech, clean speech, and enhanced speech
        # The keyword arguments below override the values preset in the `librosa_stft` partial,
        # so the visualization always uses n_fft=320 / hop_length=160 / win_length=320.
        noisy_mag, _ = librosa.magphase(self.librosa_stft(noisy, n_fft=320, hop_length=160, win_length=320))
        enhanced_mag, _ = librosa.magphase(self.librosa_stft(enhanced, n_fft=320, hop_length=160, win_length=320))
        clean_mag, _ = librosa.magphase(self.librosa_stft(clean, n_fft=320, hop_length=160, win_length=320))
        fig, axes = plt.subplots(3, 1, figsize=(6, 6))
        for k, mag in enumerate([noisy_mag, enhanced_mag, clean_mag]):
            axes[k].set_title(
                f"mean: {np.mean(mag):.3f}, "
                f"std: {np.std(mag):.3f}, "
                f"max: {np.max(mag):.3f}, "
                f"min: {np.min(mag):.3f}"
            )
            librosa.display.specshow(librosa.amplitude_to_db(mag), cmap="magma", y_axis="linear", ax=axes[k], sr=16000)
        plt.tight_layout()
        self.writer.add_figure(f"{mark}_Spectrogram/{name}", fig, epoch)

    def metrics_visualization(self, noisy_list, clean_list, enhanced_list, metrics_list, epoch, num_workers=10,
                              mark=""):
        """
        Get metrics on validation dataset by paralleling.

        Every metric is evaluated on (clean, noisy) and (clean, enhanced) pairs and the two means are
        logged to TensorBoard under "<mark>_Validation/<metric>".

        Notes:
            1. You can register other metrics, but STOI and WB_PESQ metrics must be existence. These two metrics are
             used for checking if the current epoch is a "best epoch."
            2. If you want to use a new metric, you must register it in "util.metrics" file.

        Returns:
            The average of the mean STOI and the range-transformed mean WB_PESQ on the enhanced speech.
        """
        assert "STOI" in metrics_list and "WB_PESQ" in metrics_list, "'STOI' and 'WB_PESQ' must be exist."

        # Check if the metric is registered in "util.metrics" file.
        for i in metrics_list:
            assert i in metrics.REGISTERED_METRICS.keys(), f"{i} is not registered, please check 'util.metrics' file."

        stoi_mean = 0.0
        wb_pesq_mean = 0.0
        for metric_name in metrics_list:
            score_on_noisy = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est) for ref, est in zip(clean_list, noisy_list)
            )
            score_on_enhanced = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est) for ref, est in
                zip(clean_list, enhanced_list)
            )

            # Add mean value of the metric to tensorboard
            mean_score_on_noisy = np.mean(score_on_noisy)
            mean_score_on_enhanced = np.mean(score_on_enhanced)
            self.writer.add_scalars(f"{mark}_Validation/{metric_name}", {
                "Noisy": mean_score_on_noisy,
                "Enhanced": mean_score_on_enhanced
            }, epoch)

            if metric_name == "STOI":
                stoi_mean = mean_score_on_enhanced

            if metric_name == "WB_PESQ":
                wb_pesq_mean = transform_pesq_range(mean_score_on_enhanced)

        return (stoi_mean + wb_pesq_mean) / 2

    def train(self):
        """
        Run the epoch loop from `self.start_epoch` to `self.epochs` (inclusive): train, periodically
        save checkpoints and periodically validate. Checkpointing, validation and printing happen on
        rank 0 only.
        """
        for epoch in range(self.start_epoch, self.epochs + 1):
            if self.rank == 0:
                print(f"{'=' * 15} {epoch} epoch {'=' * 15}")
                print("[0 seconds] Begin training...")

            # [debug validation] Only run validation (only use the first GPU (process))
            # inference + calculating metrics + saving checkpoints
            # NOTE(review): with only_validation set, ranks other than 0 do not take this branch and
            # fall through to regular training below — confirm this is intended.
            if self.only_validation and self.rank == 0:
                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(metric_score, save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

                # Skip the following regular training, saving checkpoints, and validation
                continue

            # Regular training
            timer = ExecutionTime()
            self._set_models_to_train_mode()
            self._train_epoch(epoch)

            #  Regular save checkpoints
            if self.rank == 0 and self.save_checkpoint_interval != 0 and (epoch % self.save_checkpoint_interval == 0):
                self._save_checkpoint(epoch)

            # Regular validation
            if self.rank == 0 and (epoch % self.validation_interval == 0):
                print(f"[{timer.duration()} seconds] Training has finished, validation is in progress...")

                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(metric_score, save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

            if self.rank == 0:
                print(f"[{timer.duration()} seconds] This epoch is finished.")

    def _train_epoch(self, epoch):
        """Train for one epoch. Must be implemented by subclasses."""
        raise NotImplementedError

    def _validation_epoch(self, epoch):
        """Validate for one epoch and return the metric score. Must be implemented by subclasses."""
        raise NotImplementedError





In [13]:

class Trainer(BaseTrainer):
    """
    Trainer that feeds the noisy magnitude spectrogram to the model and regresses the compressed
    complex ideal ratio mask (cIRM) built from the noisy / clean STFT pair.
    """

    def __init__(self, dist, rank, config, resume, only_validation, model, loss_function, optimizer, train_dataloader, validation_dataloader):
        """
        Args:
            dist, rank, config, resume, only_validation, model, loss_function, optimizer:
                forwarded unchanged to `BaseTrainer.__init__`.
            train_dataloader: iterable yielding `(noisy, clean)` batches.
            validation_dataloader: iterable yielding `(noisy, clean, name, speech_type)` batches of size one.
        """
        super().__init__(dist, rank, config, resume, only_validation, model, loss_function, optimizer)
        self.train_dataloader = train_dataloader
        self.valid_dataloader = validation_dataloader

    def _train_epoch(self, epoch):
        """
        Run one training epoch and, on rank 0, log the mean batch loss as "Loss/Train".

        Args:
            epoch (int): current epoch, used as the TensorBoard step.
        """
        loss_total = 0.0

        for noisy, clean in tqdm(self.train_dataloader, desc=f"Training {self.rank}"):
            self.optimizer.zero_grad()

            noisy = noisy.to(self.rank)
            clean = clean.to(self.rank)

            noisy_complex = self.torch_stft(noisy)
            clean_complex = self.torch_stft(clean)

            noisy_mag, _ = mag_phase(noisy_complex)
            cIRM = build_complex_ideal_ratio_mask(noisy_complex, clean_complex)  # [B, F, T, 2]

            with autocast(enabled=self.use_amp):
                # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2]
                noisy_mag = noisy_mag.unsqueeze(1)
                cRM = self.model(noisy_mag)
                cRM = cRM.permute(0, 2, 3, 1)
                loss = self.loss_function(cIRM, cRM)

            # Gradients are unscaled before clipping so the threshold applies to their true magnitude.
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm_value)
            self.scaler.step(self.optimizer)
            self.scaler.update()

            loss_total += loss.item()

        if self.rank == 0:
            self.writer.add_scalar("Loss/Train", loss_total / len(self.train_dataloader), epoch)

    @torch.no_grad()
    def _validation_epoch(self, epoch):
        """
        Run one validation pass: enhance every sample, log losses, a few audio / spectrogram
        examples and the metrics per speech type.

        Args:
            epoch (int): current epoch, used as the TensorBoard step.

        Returns:
            The metric score computed on the "No_reverb" samples.
        """
        visualization_n_samples = self.visualization_config["n_samples"]
        visualization_num_workers = self.visualization_config["num_workers"]
        visualization_metrics = self.visualization_config["metrics"]

        loss_total = 0.0
        loss_list = {"With_reverb": 0.0, "No_reverb": 0.0, }
        item_idx_list = {"With_reverb": 0, "No_reverb": 0, }
        noisy_y_list = {"With_reverb": [], "No_reverb": [], }
        clean_y_list = {"With_reverb": [], "No_reverb": [], }
        enhanced_y_list = {"With_reverb": [], "No_reverb": [], }
        validation_score_list = {"With_reverb": 0.0, "No_reverb": 0.0}

        # speech_type in ("With_reverb", "No_reverb")
        # Wrapping the dataloader directly (instead of an enumerate over it) lets tqdm show the total.
        for noisy, clean, name, speech_type in tqdm(self.valid_dataloader, desc="Validation"):
            assert len(name) == 1, "The batch size for the validation stage must be one."
            name = name[0]
            speech_type = speech_type[0]

            noisy = noisy.to(self.rank)
            clean = clean.to(self.rank)

            noisy_complex = self.torch_stft(noisy)
            clean_complex = self.torch_stft(clean)

            noisy_mag, _ = mag_phase(noisy_complex)
            cIRM = build_complex_ideal_ratio_mask(noisy_complex, clean_complex)  # [B, F, T, 2]

            noisy_mag = noisy_mag.unsqueeze(1)
            cRM = self.model(noisy_mag)
            cRM = cRM.permute(0, 2, 3, 1)

            # Accumulate plain Python floats, as in `_train_epoch`, rather than device tensors.
            loss_value = self.loss_function(cIRM, cRM).item()

            cRM = decompress_cIRM(cRM)

            # Complex multiplication of the predicted mask with the noisy spectrum.
            enhanced_real = cRM[..., 0] * noisy_complex.real - cRM[..., 1] * noisy_complex.imag
            enhanced_imag = cRM[..., 1] * noisy_complex.real + cRM[..., 0] * noisy_complex.imag
            enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1)
            enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1))

            noisy = noisy.detach().squeeze(0).cpu().numpy()
            clean = clean.detach().squeeze(0).cpu().numpy()
            enhanced = enhanced.detach().squeeze(0).cpu().numpy()

            assert len(noisy) == len(clean) == len(enhanced)
            loss_total += loss_value

            # Separated loss
            loss_list[speech_type] += loss_value
            item_idx_list[speech_type] += 1

            # Only the first `n_samples` items of each speech type are visualized.
            if item_idx_list[speech_type] <= visualization_n_samples:
                self.spec_audio_visualization(noisy, enhanced, clean, name, epoch, mark=speech_type)

            noisy_y_list[speech_type].append(noisy)
            clean_y_list[speech_type].append(clean)
            enhanced_y_list[speech_type].append(enhanced)

        self.writer.add_scalar("Loss/Validation_Total", loss_total / len(self.valid_dataloader), epoch)

        for speech_type in ("With_reverb", "No_reverb"):
            # Average over the samples of this speech type. Dividing by the size of the whole
            # dataloader under-reported the per-type loss whenever both types were present.
            n_items = item_idx_list[speech_type]
            mean_loss = loss_list[speech_type] / n_items if n_items else 0.0
            self.writer.add_scalar(f"Loss/{speech_type}", mean_loss, epoch)

            validation_score_list[speech_type] = self.metrics_visualization(
                noisy_y_list[speech_type], clean_y_list[speech_type], enhanced_y_list[speech_type],
                visualization_metrics, epoch, visualization_num_workers, mark=speech_type
            )

        return validation_score_list["No_reverb"]
