In [1]:
import os

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

class UNetDenoiser(nn.Module):
    def __init__(self, in_channels=1, base_channels=64):
        super().__init__()

        # encoder
        self.enc1 = nn.Sequential(
            nn.Conv1d(in_channels, base_channels, 15, stride=1, padding=7),
            nn.ReLU(),
            nn.Conv1d(base_channels, base_channels, 15, stride=2, padding=7),
            nn.BatchNorm1d(base_channels),
            nn.ReLU(),
        )
        self.enc2 = nn.Sequential(
            nn.Conv1d(base_channels, base_channels*2, 15, stride=2, padding=7),
            nn.BatchNorm1d(base_channels*2),
            nn.ReLU(),
        )
        self.enc3 = nn.Sequential(
            nn.Conv1d(base_channels*2, base_channels*4, 15, stride=2, padding=7),
            nn.BatchNorm1d(base_channels*4),
            nn.ReLU(),
        )

        # bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv1d(base_channels*4, base_channels*4, 15, padding=7),
            nn.ReLU()
        )

        # decoder
        self.dec3 = nn.Sequential(
            nn.ConvTranspose1d(base_channels*4, base_channels*2, 15, stride=2, padding=7, output_padding=1),
            nn.BatchNorm1d(base_channels*2),
            nn.ReLU(),
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose1d(base_channels*4, base_channels, 15, stride=2, padding=7, output_padding=1),
            nn.BatchNorm1d(base_channels),
            nn.ReLU(),
        )
        self.dec1 = nn.Sequential(
            nn.ConvTranspose1d(base_channels*2, 1, 15, stride=2, padding=7, output_padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # x: (B, 1, T)
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        b = self.bottleneck(e3)
        d3 = self.dec3(b)
        d2 = self.dec2(torch.cat([d3, e2], dim=1))
        d1 = self.dec1(torch.cat([d2, e1], dim=1))
        return d1


class MultiResSTFTLoss(nn.Module):
    def __init__(self, fft_sizes=[512, 1024, 2048], hop_sizes=[128, 256, 512], win_lengths=[512, 1024, 2048]):
        super().__init__()
        self.fft_sizes = fft_sizes
        self.hop_sizes = hop_sizes
        self.win_lengths = win_lengths

    def forward(self, x, y):
        loss = 0.0
        for fft, hop, win in zip(self.fft_sizes, self.hop_sizes, self.win_lengths):
            X = torch.stft(x.squeeze(1), fft, hop, win, return_complex=True)
            Y = torch.stft(y.squeeze(1), fft, hop, win, return_complex=True)
            magX, magY = torch.abs(X), torch.abs(Y)
            loss += torch.mean(torch.abs(magX - magY)) / len(self.fft_sizes)
        return loss


def train_denoiser(model, dataloader, optimizer, device):
    l1_loss = nn.L1Loss()
    stft_loss = MultiResSTFTLoss().to(device)
    model.train()

    for noisy, clean in dataloader:
        noisy = noisy.to(device)
        clean = clean.to(device)

        pred = model(noisy)
        loss_l1 = l1_loss(pred, clean)
        loss_stft = stft_loss(pred, clean)
        loss = loss_l1 + 0.5 * loss_stft

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss.item()


In [3]:
from torch.utils.data import Dataset
import soundfile as sf
import numpy as np

class DenoiseDataset(Dataset):
    def __init__(self, pairs, sr=16000):
        self.pairs = pairs
        self.sr = sr

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        clean_path, noisy_path = self.pairs[idx]
        clean, _ = sf.read(clean_path)
        noisy, _ = sf.read(noisy_path)
        clean = np.mean(clean, axis=-1) if clean.ndim > 1 else clean
        noisy = np.mean(noisy, axis=-1) if noisy.ndim > 1 else noisy

        clean = torch.tensor(clean).float().unsqueeze(0)
        noisy = torch.tensor(noisy).float().unsqueeze(0)

        # Pad or crop to same length
        min_len = min(clean.size(-1), noisy.size(-1))
        clean, noisy = clean[..., :min_len], noisy[..., :min_len]

        return noisy, clean


In [None]:
import os

clean_dir = '../../data/denoise/output_pairs/clean'
noisy_dir = '../../data/denoise/output_pairs/noisy'

train_pairs = []
for entry in os.scandir(clean_dir):
    if entry.is_file() and entry.name.endswith('.wav'):
        clean_path = entry.path
        noisy_path = os.path.join(noisy_dir, entry.name)
        if os.path.exists(noisy_path):
            train_pairs.append((clean_path, noisy_path))
        else:
            print(f"⚠️ No matching noisy file for {entry.name}")


In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = UNetDenoiser().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# example
train_loader = torch.utils.data.DataLoader(DenoiseDataset(train_pairs), batch_size=8, shuffle=True)

for epoch in range(50):
    loss = train_denoiser(model, train_loader, optimizer, device)
    print(f"Epoch {epoch}: Loss = {loss:.4f}")


Epoch 0: Loss = 1.2062
Epoch 1: Loss = 0.6816
Epoch 2: Loss = 0.6104
Epoch 3: Loss = 0.6849
Epoch 4: Loss = 0.4741
Epoch 5: Loss = 0.4404
Epoch 6: Loss = 0.4952
Epoch 7: Loss = 0.5036
Epoch 8: Loss = 0.3677


KeyboardInterrupt: 

In [8]:
print (epoch)


9


In [10]:
current_model = model

In [11]:
for epoch_2 in range(5):
    loss = train_denoiser(model, train_loader, optimizer, device)
    print(f"Epoch {epoch_2}: Loss = {loss:.4f}")


Epoch 0: Loss = 0.2417
Epoch 1: Loss = 0.5908
Epoch 2: Loss = 0.2405
Epoch 3: Loss = 0.2476
Epoch 4: Loss = 0.3441


In [12]:
import torch
import soundfile as sf
import numpy as np

def denoise_audio(model, noisy_path, output_path, device="cpu"):
    model.eval()  # set model to evaluation mode

    # Load audio
    noisy, sr = sf.read(noisy_path)
    if noisy.ndim > 1:
        noisy = np.mean(noisy, axis=-1)  # convert to mono
    noisy = noisy / (np.max(np.abs(noisy)) + 1e-6)  # normalize

    # Convert to tensor
    wave_tensor = torch.tensor(noisy).float().unsqueeze(0).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        denoised_tensor = model(wave_tensor)

    # Convert back to numpy
    denoised_wave = denoised_tensor.squeeze().cpu().numpy()

    # Save output
    sf.write(output_path, denoised_wave, sr)
    print(f"Denoised audio saved to {output_path}")

# Example usage
device = "cuda" if torch.cuda.is_available() else "cpu"
denoise_audio(model, "noised_hrystia_piano.wav", "denoised_music.wav", device)


Denoised audio saved to denoised_music.wav


In [13]:
# Save model and optimizer state
save_path = "unet_denoiser_more_epochs.pt"
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch,
    "loss": loss
}, save_path)

print(f"✅ Model saved to {save_path}")


✅ Model saved to unet_denoiser_more_epochs.pt


In [14]:
# Save model and optimizer state
save_path = "unet_denoiser_less_epochs.pt"
torch.save({
    "model_state_dict": current_model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch,
    "loss": loss
}, save_path)

print(f"✅ Model saved to {save_path}")


✅ Model saved to unet_denoiser_less_epochs.pt
