In [12]:
import argparse
import os

import torch
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
from torchaudio.models import ConvTasNet
from torchsummary import summary
from torchvision import transforms
from tqdm import tqdm
import wandb

from getmodel import get_model
import numpy as np

In [3]:
class Trainer:
    def __init__(
        self,
        train_data,
        val_data,
        checkpoint_name,
        display_freq=10,
        useWandB = False
    ):
        self.train_data = train_data
        self.val_data = val_data
        assert checkpoint_name.endswith(".tar"), "The checkpoint file must have .tar extension"
        self.checkpoint_name = checkpoint_name
        self.display_freq = display_freq
        self.useWandB = useWandB

    def fit(
        self,
        model,
        device,
        epochs=10,
        batch_size=16,
        lr=0.001,
        weight_decay=1e-5,
        optimizer=optim.Adam,
        loss_fn=F.mse_loss,
        loss_mode="min",
        gradient_clipping=True,
    ):
        
        if(self.useWandB):
              print("Using WandB")
              wandb.init(
                        # set the wandb project where this run will be logged
                        project="ConvTasnetImpl",
                        name= f'ConvTasNet: Initial Runs',
                        config={
                            "epochs":epochs,
                            "learning_rate":lr ,
                            "batch_size": batch_size,
                        }
                    )
              wandb.log({"model trainable parameters": sum(p.numel() for p in model.parameters() if p.requires_grad)})

        # Get the device placement and make data loaders
        self.device = device
        print(f"Using device: {self.device}")
        kwargs = {"num_workers": 1, "pin_memory": True} if device == "cuda" else {}
        self.train_loader = torch.utils.data.DataLoader(self.train_data, batch_size=batch_size,generator = torch.Generator(device=device),**kwargs)
        self.val_loader = torch.utils.data.DataLoader(self.val_data, batch_size=batch_size,generator = torch.Generator(device=device), **kwargs)

        self.optimizer = optimizer(model.parameters(), lr=lr, weight_decay=weight_decay)
        self.loss_fn = loss_fn
        self.loss_mode = loss_mode
        self.gradient_clipping = gradient_clipping
        self.history = {"train_loss": [], "test_loss": []}

        previous_epochs = 0
        best_loss = None

        # Try loading checkpoint (if it exists)
        if os.path.isfile(self.checkpoint_name):
            print(f"Resuming training from checkpoint: {self.checkpoint_name}")
            checkpoint = torch.load(self.checkpoint_name)
            model.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            self.loss_fn = checkpoint["loss_fn"]
            self.history = checkpoint["history"]
            previous_epochs = checkpoint["epoch"]
            best_loss = checkpoint["best_loss"]
        else:
            print(f"No checkpoint found, using default parameters...")

        for epoch in range(previous_epochs + 1, epochs + 1):
            print(f"\nEpoch {epoch}/{epochs}:")
            train_loss = self.train(model)
            test_loss = self.test(model)

            if(self.useWandB):
                wandb.log({"train_loss": train_loss, "test_loss": test_loss},step = epoch)
            
            self.history["train_loss"].append(train_loss)
            self.history["test_loss"].append(test_loss)

            # Save checkpoint only if the validation loss improves (avoid overfitting)
            if (
                best_loss is None
                or (test_loss < best_loss and self.loss_mode == "min")
                or (test_loss > best_loss and self.loss_mode == "max")
            ):
                print(f"Validation loss improved from {best_loss} to {test_loss}.")
                print(f"Saving checkpoint to: {self.checkpoint_name}")
                best_loss = test_loss

                checkpoint_data = {
                    "epoch": epoch,
                    "best_loss": best_loss,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                    "loss_fn": self.loss_fn,
                    "history": self.history,
                }
                torch.save(checkpoint_data, f'{self.checkpoint_name}_shobhit')
       
        if(self.useWandB):
            wandb.finish()

        return self.history

    def train(self, model):
        total_loss = 0.0
        model.train()
        with tqdm(self.train_loader) as progress:
            for i, (mixture, sources) in enumerate(progress):
                mixture = mixture.to(self.device)
                sources = sources.to(self.device)

                self.optimizer.zero_grad()

                predictions = model(mixture)
                loss = self.loss_fn(predictions, sources)

                if self.loss_mode == "max":  # To optimize for maximization, multiply by -1
                    loss = -1 * loss

                loss.mean().backward()

                # Gradient Value Clipping
                if self.gradient_clipping:
                    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)

                self.optimizer.step()

                total_loss += loss.mean().item()

                if i % self.display_freq == 0:
                    progress.set_postfix(
                        {
                            "loss": float(total_loss / (i + 1)),
                        }
                    )

        total_loss /= len(self.train_loader)
        return total_loss

    def test(self, model):
        total_loss = 0.0
        model.eval()
        with torch.no_grad():
            with tqdm(self.val_loader) as progress:
                for i, (mixture, sources) in enumerate(progress):
                    mixture = mixture.to(self.device)
                    sources = sources.to(self.device)

                    predictions = model(mixture)

                    loss = self.loss_fn(predictions, sources)

                    total_loss += loss.mean().item()

                    if i % self.display_freq == 0:
                        progress.set_postfix(
                            {
                                "loss": float(total_loss / (i + 1)),
                            }
                        )

        total_loss /= len(self.val_loader)
        return total_loss

In [6]:
# ap = argparse.ArgumentParser()
# ap.add_argument("--clean_train_path",default= "/datasets/Train/cleanSliced" )
clean_train_path = "/datasets/Train/cleanSliced"

# ap.add_argument("--clean_val_path",default = "/datasets/Validation/cleanSliced")
clean_val_path = "/datasets/Validation/cleanSliced"

# ap.add_argument("--noise_train_path",default = "/datasets/Train/noisySliced")
noise_train_path = "/datasets/Train/noisySliced"

# ap.add_argument("--noise_val_path",default = "/datasets/Validation/noisySliced")
noise_val_path = "/datasets/Validation/noisySliced"
# ap.add_argument("--keep_rate", default=0.8, type=float)
keep_rate = 0.8

# Model checkpoint
model = "ConvTasNet"

# ap.add_argument("--model",default="ConvTasNet", choices=["UNet", "UNetDNP", "ConvTasNet", "TransUNet", "SepFormer"])
# ap.add_argument("--checkpoint_name", default= "doesntExist.tar",help="File with .tar extension")
checkpoint_name = "doesntExist.tar"

# Training params
epochs = 100
batch_size = 16
lr = 1e-4
gradient_clipping = True
# ap.add_argument("--epochs", default=100, type=int)
# ap.add_argument("--batch_size", default=16, type=int)
# ap.add_argument("--lr", default=1e-4, type=float)
# ap.add_argument("--gradient_clipping", action="store_true")

# GPU setup
# ap.add_argument("--gpu", default="-1")
# ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")

# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# visible_devices = list(map(lambda x: int(x), args.gpu.split(",")))
# print("Visible devices:", visible_devices)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_tensor_type('torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor')
print(f"Using device: {device}")

# from torchaudio.models import ConvTasNet

# from losses import LogSTFTMagnitudeLoss, MultiResolutionSTFTLoss, ScaleInvariantSDRLoss, SpectralConvergenceLoss
# from models import *

# Select the model to be used for training
training_utils_dict = get_model("ConvTasNet")

model = training_utils_dict["model"]
data_mode = training_utils_dict["data_mode"]
loss_fn = training_utils_dict["loss_fn"]
loss_mode = training_utils_dict["loss_mode"]

# model = torch.nn.DataParallel(model, device_ids=list(range(len(visible_devices))))
model = model.to(device)

from data import AudioDirectoryDataset, NoiseMixerDataset

train_data = NoiseMixerDataset(
    clean_dataset=AudioDirectoryDataset(root=clean_train_path, keep_rate=keep_rate),
    noise_dataset=AudioDirectoryDataset(root=noise_train_path, keep_rate=keep_rate),
    mode=data_mode,
)

val_data = NoiseMixerDataset(
    clean_dataset=AudioDirectoryDataset(root=clean_val_path, keep_rate=keep_rate),
    noise_dataset=AudioDirectoryDataset(root=noise_val_path, keep_rate=keep_rate),
    mode=data_mode,
)

useWandB = False

# trainer = Trainer(train_data, val_data, checkpoint_name=args.checkpoint_name,useWandB = useWandB)
# history = trainer.fit(
#     model,
#     device,
#     epochs=args.epochs,
#     batch_size=args.batch_size,
#     lr=args.lr,
#     loss_fn=loss_fn,
#     loss_mode=loss_mode,
#     gradient_clipping=args.gradient_clipping,
# )


Using device: cuda


In [21]:
import sounddevice as sd
import IPython.display as ipd
mixture, sources = train_data[0]
print(mixture.shape)
sample_rate = 16000
# ipd.Audio(mixture, rate=sample_rate)

# ipd.Audio(sources[0], rate=sample_rate) 

torch.Size([1, 64000])


In [32]:
from pystoi import stoi
from pesq import pesq

trainingSample = []
cleanSample = []
for i,(mixture, sources) in enumerate(train_data):
    if(i == 0):
        trainingSample = np.array(mixture)
        cleanSample = np.array(sources[0])
    elif i == 3:
        break
    else:
     print(f'trainingSample shape: {trainingSample.shape}')
     print(f'cleanSample shape: {cleanSample.shape}')
     trainingSample = np.concatenate([trainingSample,mixture],axis=1)
     cleanSample = np.concatenate([cleanSample,sources[0]],axis=0)

# ipd.Audio(trainingSample, rate=sample_rate)
ipd.Audio(cleanSample, rate=sample_rate)


pesq_score = pesq(fs = sample_rate,ref = cleanSample,deg = np.squeeze(trainingSample))
print(f'Pesq Score: {pesq_score}')

stoi_score = stoi(x=cleanSample,y = np.squeeze(trainingSample),fs_sig= sample_rate,extended=True)
print(f'EStoi Score: {stoi_score}')


trainingSample shape: (1, 64000)
cleanSample shape: (64000,)
trainingSample shape: (1, 128000)
cleanSample shape: (128000,)
Pesq Score: 1.5887751579284668
Stoi Score: 0.8173428541857467


In [None]:
trainer = Trainer(train_data, val_data, checkpoint_name=args.checkpoint_name,useWandB = useWandB)
history = trainer.fit(
    model,
    device,
    epochs=args.epochs,
    batch_size=args.batch_size,
    lr=args.lr,
    loss_fn=loss_fn,
    loss_mode=loss_mode,
    gradient_clipping=args.gradient_clipping,
)
