In [1]:
import numpy as np
import math, random
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader, Dataset
import os
from asteroid.metrics import get_metrics
from asteroid.losses import pairwise_neg_sisdr, pairwise_neg_snr, singlesrc_neg_sisdr, singlesrc_neg_snr

import numpy as np
import matplotlib.pyplot as plt
from glob import glob

from IPython.display import Audio
from tqdm import tqdm
import time

In [2]:
# Directory holding all_mono.wav and its all_mono_{idx}.wav chunks.
# Override with the SIREN_DATAPATH environment variable instead of editing the absolute path.
datapath = os.environ.get('SIREN_DATAPATH', 'C:/Users/USER/Desktop/all_mono_1')

In [None]:
## Only needed when (re)creating the chunk files: load the full recording and optionally
## cut it into fixed-length all_mono_{idx}.wav chunks.
SPLIT_SECONDS = 2      # chunk length in seconds used when writing chunk files
WRITE_CHUNKS = False   # set True to (re)write the chunk files into datapath

# normalize=False keeps the file's native sample values.
audio, rate = torchaudio.load(os.path.join(datapath, 'all_mono.wav'), normalize=False)
if WRITE_CHUNKS:
    chunk_len = rate * SPLIT_SECONDS
    # The last chunk is shorter when the recording length is not a multiple of chunk_len.
    for idx, start in enumerate(range(0, audio.shape[-1], chunk_len)):
        torchaudio.save(os.path.join(datapath, f'all_mono_{idx}.wav'),
                        audio[:, start:start + chunk_len], sample_rate=rate)

In [None]:
def get_mgrid(sidelen, dim=2):
    '''Generates a flattened grid of (x, y, ...) coordinates in a range of -1 to 1.

    sidelen: int, number of points along every axis.
    dim: int, number of axes.
    Returns a float tensor of shape (sidelen ** dim, dim). Rows enumerate the grid with the
    last coordinate varying fastest ('ij' / matrix indexing).
    '''
    axis = torch.linspace(-1, 1, steps=sidelen)
    # indexing='ij' is what torch.meshgrid did by default; passing it explicitly avoids the
    # "indexing argument will be required" deprecation warning without changing the result.
    mgrid = torch.stack(torch.meshgrid(*(dim * [axis]), indexing='ij'), dim=-1)
    return mgrid.reshape(-1, dim)
    

class SineLayer(nn.Module):
    '''SIREN layer: y = sin(omega_0 * (W x + b)).

    Weights are drawn uniformly in [-1/in, 1/in] for the first layer (is_first=True) and in
    [-sqrt(6/in)/omega_0, sqrt(6/in)/omega_0] otherwise; the bias keeps nn.Linear's default init.
    '''

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        # Idea: W x + b could be swapped for a convolution or another layer to change model complexity.
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        '''Re-draw the weight matrix in place, outside autograd.'''
        if self.is_first:
            limit = 1 / self.in_features
        else:
            limit = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, input):
        activated, _ = self.forward_with_intermediate(input)
        return activated

    def forward_with_intermediate(self, input):
        '''Return (sin(z), z) with z = omega_0 * linear(input); z is exposed to visualise activation distributions.'''
        pre_activation = self.omega_0 * self.linear(input)
        return torch.sin(pre_activation), pre_activation


class reluLayer(nn.Module):
    '''Linear layer followed by a scaled ReLU: y = relu(omega_0 * (W x + b)).

    Mirrors SineLayer's constructor and weight initialisation, with sin replaced by ReLU.
    For omega_0 > 0, relu(omega_0 * z) == omega_0 * relu(z), so omega_0 only rescales the output.
    NOTE(review): the init bounds are the SIREN ones (derived for sine activations) — confirm they
    are intended for a ReLU layer.
    '''

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        '''Re-draw the weight matrix uniformly in place (same bounds as SineLayer).'''
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features,
                                             1 / self.in_features)
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
                                             np.sqrt(6 / self.in_features) / self.omega_0)

    def forward(self, input):
        return F.relu(self.omega_0 * self.linear(input))

    def forward_with_intermediate(self, input):
        '''Return (activation, pre-activation) for visualising activation distributions.

        The activation uses F.relu so it matches forward(); applying sin here would report
        activations this layer never produces.
        '''
        intermediate = self.omega_0 * self.linear(input)
        return F.relu(intermediate), intermediate


class Siren(nn.Module):
    '''Plain SIREN MLP stored in the nn.Sequential `self.net`.

    Layout: SineLayer(in -> hidden, is_first) -> `hidden_layers` x SineLayer(hidden -> hidden)
    -> output layer. The output layer is nn.Linear with SIREN-style uniform weight init when
    outermost_linear is True, otherwise another SineLayer.
    '''

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        layers = [SineLayer(in_features, hidden_features,
                            is_first=True, omega_0=first_omega_0)]
        layers += [SineLayer(hidden_features, hidden_features,
                             is_first=False, omega_0=hidden_omega_0)
                   for _ in range(hidden_layers)]

        if not outermost_linear:
            head = SineLayer(hidden_features, out_features,
                             is_first=False, omega_0=hidden_omega_0)
        else:
            head = nn.Linear(hidden_features, out_features)
            limit = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                head.weight.uniform_(-limit, limit)
        layers.append(head)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        '''Return (output, coords) where coords is a fresh leaf with requires_grad=True,
        so derivatives of the output w.r.t. the input coordinates can be taken.'''
        coords = coords.clone().detach().requires_grad_(True)
        return self.net(coords), coords

    def forward_with_activations(self, coords, retain_grad=False):
        '''Run the network and collect intermediate tensors, keyed "<layer class>_<counter>".

        For SineLayers both the pre-activation and the activation are recorded (two keys);
        other layers record only their output. With retain_grad=True the recorded tensors keep
        their .grad after backward. Only used for visualising activations.
        '''
        activations = {}
        counter = 0
        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x

        for layer in self.net:
            tag = str(layer.__class__)
            if isinstance(layer, SineLayer):
                x, pre_activation = layer.forward_with_intermediate(x)
                if retain_grad:
                    x.retain_grad()
                    pre_activation.retain_grad()
                activations[f'{tag}_{counter}'] = pre_activation
                counter += 1
            else:
                x = layer(x)
                if retain_grad:
                    x.retain_grad()

            activations[f'{tag}_{counter}'] = x
            counter += 1

        return activations


class AudioFile(torch.utils.data.Dataset):
    '''Single-item dataset: one audio file as (time coordinates, amplitudes) for fitting an implicit model.

    Only the first channel is kept. Time coordinates come from get_mgrid(N, 1) (N points in
    [-1, 1]); amplitudes are peak-normalised so the largest magnitude is 1.
    '''

    def __init__(self, filename):
        # normalize=False keeps native sample values (presumably int16 PCM — confirm per file).
        self.data, self.rate = torchaudio.load(filename, normalize=False)
        # (channels, frames) -> (frames, channels) -> first channel only, as a 1-D numpy array.
        self.data = self.data.transpose(-1, -2).numpy()[..., 0]
        self.timepoints = get_mgrid(len(self.data), 1)

    def get_num_samples(self):
        return self.timepoints.shape[0]

    def __len__(self):
        # The whole file is a single training example.
        return 1

    def __getitem__(self, idx):
        '''Return (timepoints (N, 1), amplitude (N, 1) float32, peak-normalised); idx is ignored.'''
        # Scale int16 full range to [-1, 1); the peak normalisation below makes this scale-free anyway.
        amplitude = torch.Tensor(self.data / 32768.).view(-1, 1)
        if amplitude.numel() > 0:
            peak = amplitude.abs().max()
            # A silent (all-zero) chunk would otherwise divide by zero and yield NaN targets.
            if peak > 0:
                amplitude = amplitude / peak
        return self.timepoints, amplitude

def spectrogram(wav, n_fft=1024, top_db=80):
    '''Plot the dB-scaled power spectrogram of `wav[0]` in a new 20x10 figure; returns None.

    wav: waveform tensor; element [0] of the transform output is plotted, so presumably
         (channels, time) input — confirm.
    n_fft: FFT size (n_fft // 2 + 1 frequency bins).
    top_db: dynamic-range clamp passed to AmplitudeToDB.
    NOTE: a later cell rebinds the name `spectrogram` to a torchaudio.transforms.Spectrogram instance.
    '''
    power = torchaudio.transforms.Spectrogram(n_fft=n_fft)(wav)[0]
    power_db = torchaudio.transforms.AmplitudeToDB(top_db=top_db)(power).numpy()
    plt.figure(figsize=(20, 10))
    # origin='lower' draws bin 0 at the bottom AND keeps the y tick labels equal to the real bin index;
    # flipping the array with [::-1] gives the same picture but labels the top row as bin 0.
    im = plt.imshow(power_db, origin='lower')
    plt.colorbar(im)

In [None]:
def _chunk_index(path):
    '''Trailing integer of a chunk file name, e.g. ".../all_mono_12.wav" -> 12.'''
    stem = os.path.basename(path).split('.')[0]
    return int(stem.split('_')[-1])


# Chunk files in numeric order (all_mono_2 before all_mono_10), not lexicographic order.
wavpath = sorted(glob(os.path.join(datapath, 'all_mono_*.wav')), key=_chunk_index)

In [None]:
class Siren2(nn.Module):
    '''SIREN variant with residual connections between the hidden sine layers.

    Layout: net1 = SineLayer(in -> hidden, is_first), then for k = 2 .. hidden_layers + 1:
    h = net{k}(h) + h, then `final` (nn.Linear with SIREN-style init if outermost_linear,
    otherwise a SineLayer). Hidden layers are registered as attributes net2, net3, ..., so the
    state_dict keys are net1 .. net{hidden_layers + 1}, final (net1..net4 for hidden_layers=3).
    '''

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()
        # Honour hidden_layers (previously fixed at 3 regardless of the argument).
        self.num_hidden = hidden_layers

        self.net1 = SineLayer(in_features, hidden_features, is_first=True, omega_0=first_omega_0)
        for k in range(2, hidden_layers + 2):
            setattr(self, f'net{k}',
                    SineLayer(hidden_features, hidden_features, is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)

            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
                                              np.sqrt(6 / hidden_features) / hidden_omega_0)

            self.final = final_linear
        else:
            self.final = SineLayer(hidden_features, out_features,
                                   is_first=False, omega_0=hidden_omega_0)

    def forward(self, coords):
        '''Return (output, coords); coords is a fresh leaf with requires_grad=True so derivatives
        w.r.t. the input coordinates can be taken.'''
        coords = coords.clone().detach().requires_grad_(True)

        h = self.net1(coords)
        for k in range(2, self.num_hidden + 2):
            # Residual connection: each hidden layer refines the running representation.
            h = getattr(self, f'net{k}')(h) + h
        output = self.final(h)
        return output, coords

# Fit one Siren2 per audio chunk, keep the best weights from the last 10% of training,
# save the reconstruction and score it with PESQ at 16 kHz.
total_steps = 2000
steps_til_summary = 1000  # currently unused
lr = 1e-4
seed = 0  # re-seeded per chunk so each chunk's initialisation is reproducible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pesq = []
losses = []
name = 'Siren2_15k_128_SNR_recons_'

for wave in wavpath:
    bach_audio = AudioFile(wave)
    # Use this chunk's own sample rate, not a `rate` left over from another cell.
    rate = bach_audio.rate

    torch.manual_seed(seed)
    # TODO: try weight quantization, and trade fewer hidden_features for more hidden_layers.
    audio_siren = Siren2(in_features=1, out_features=1, hidden_features=128,
                         hidden_layers=3, first_omega_0=15000, hidden_omega_0=200, outermost_linear=True)
    audio_siren.to(device)
    dataloader = DataLoader(bach_audio, shuffle=True, batch_size=1,
                            pin_memory=(device.type == 'cuda'), num_workers=0)

    model_input, ground_truth = next(iter(dataloader))
    model_input, ground_truth = model_input.to(device), ground_truth.to(device)

    optim = torch.optim.Adam(lr=lr, params=audio_siren.parameters())
    decay = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, factor=1/2**0.5, patience=20)

    minloss = float('inf')
    best = None  # reset per chunk so a previous chunk's weights can never be loaded here
    with tqdm(range(total_steps)) as pbar:
        for step in pbar:
            optim.zero_grad()
            model_output, coords = audio_siren(model_input)
            loss = singlesrc_neg_snr(model_output.squeeze(-1), ground_truth.squeeze(-1))
            loss_value = loss.item()
            pbar.set_postfix({'loss': loss_value})

            # Snapshot BEFORE optim.step(): these are the weights that produced `loss`.
            # state_dict() returns references to the live parameters, so clone to freeze them.
            if step > int(total_steps * 0.9) and minloss >= loss_value:
                minloss = loss_value
                best = {k: v.detach().clone() for k, v in audio_siren.state_dict().items()}

            loss.backward()
            optim.step()
            decay.step(loss_value)
    losses.append(minloss)
    if best is not None:
        audio_siren.load_state_dict(best)
    audio_siren.eval()
    with torch.no_grad():
        model_output, _ = audio_siren(model_input)

    model_output = model_output.float()
    ground_truth = ground_truth.float()
    torchaudio.save(os.path.join(datapath, name + os.path.basename(wave)),
                    model_output.squeeze(-1).cpu(), sample_rate=rate)

    # PESQ is computed on 16 kHz resampled, 1-D signals.
    estimate_16k = torchaudio.functional.resample(model_output.squeeze(-1), rate, 16000).squeeze().cpu().numpy()
    reference_16k = torchaudio.functional.resample(ground_truth.squeeze(-1), rate, 16000).squeeze().cpu().numpy()
    # get_metrics(mix, clean, estimate, ...): the estimate is also passed as the "mix" argument.
    pesq.append(get_metrics(estimate_16k, reference_16k, estimate_16k,
                            sample_rate=16000, metrics_list=['pesq'])['pesq'])
    print(pesq[-1])

if pesq:
    print(max(pesq), min(pesq), np.mean(pesq))
    fig, ax = plt.subplots()
    ax.scatter(np.arange(len(pesq)), pesq)
    ax.axhline(3, color='C1')  # reference level PESQ = 3
    ax.axhline(4, color='C2')  # reference level PESQ = 4
    ax.set(xlabel='chunk index', ylabel='PESQ', title=name)
    print(pesq)

In [None]:
aa = [i.shape for i in audio_siren.state_dict().values()]
def f(x):
    res = 1.
    for i in x:
        res *= i
    return res
aa = sum([f(i) for i in aa])
# 압축률: 원본 bitrate / 압축본 bitrate, BPS(Bit per second)
768 / (aa * 32 / 1 / 1000)

In [None]:
aa

In [None]:
(aa * 32 / 1 / 1000)

In [None]:
num_params = sum(p.numel() for p in audio_siren.parameters() if p.requires_grad)

In [None]:
num_params

In [None]:
bb = 799170

In [None]:
768 / (bb * 32 / 1 / 1000)

In [None]:
(bb * 32 / 1 / 1000)

In [None]:
def plot_waveform(waveform, sample_rate):
    '''Plot every channel of a waveform against time in seconds, one subplot per channel.

    waveform: (channels, frames) tensor/array; a 1-D input is treated as a single channel.
    sample_rate: samples per second, used to build the time axis.
    '''
    if waveform.ndim == 1:
        waveform = waveform[None]

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    # squeeze=False always yields a 2-D array of Axes, so no special case for one channel.
    figure, axes = plt.subplots(num_channels, 1, squeeze=False)
    for c, ax in enumerate(axes[:, 0]):
        ax.plot(time_axis, waveform[c], linewidth=1)
        ax.grid(True)
        if num_channels > 1:
            ax.set_ylabel(f"Channel {c+1}")
    axes[-1, 0].set_xlabel("Time [s]")
    figure.suptitle("waveform")
    plt.show(block=False)

In [None]:
import torchaudio.transforms as T
import librosa

def plot_spectrogram(specgram, title=None, ylabel="freq_bin"):
    '''Show a power spectrogram converted to dB with librosa.power_to_db.

    specgram: 2-D power values, presumably (freq_bin, frame) — one channel of a
              torchaudio Spectrogram output; confirm.
    title: figure title; "Spectrogram (db)" when not given.
    ylabel: y-axis label.
    '''
    fig, ax = plt.subplots(1, 1)
    ax.set_title(title if title else "Spectrogram (db)")
    ax.set_ylabel(ylabel)
    ax.set_xlabel("frame")
    spec_db = librosa.power_to_db(specgram)
    image = ax.imshow(spec_db, origin="lower", aspect="auto")
    fig.colorbar(image, ax=ax)
    plt.show(block=False)
    
# STFT settings for the spectrogram plots below.
n_fft = 1024
win_length = None  # None -> torchaudio uses a window of n_fft samples
hop_length = 512

# Power spectrogram (power=2.0) with centred, reflect-padded frames.
# NOTE: this rebinds `spectrogram`, replacing the plotting function of the same name defined earlier;
# from here on spectrogram(wav) returns a tensor instead of drawing a figure.
stft_config = dict(
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    center=True,
    pad_mode="reflect",
    power=2.0,
)
spectrogram = T.Spectrogram(**stft_config)

In [None]:
# Trained on 1-second audio chunks

In [None]:
path = 'C:/Users/USER/Desktop/continuous-audio-representations-main/results/default/SPEECHCOMMANDS/wavegan/autodecoder/audio/'

In [None]:
audio, rate = torchaudio.load(path + 'original_3.wav')

In [None]:
import IPython
IPython.display.Audio(data=audio, rate=16000)

In [None]:
spec = spectrogram(audio)
plot_spectrogram(spec[0], title="torchaudio")

In [None]:
recon_audio, rate = torchaudio.load(path + 'reconstruction_epoch_10000_3.wav')

In [None]:
import IPython
IPython.display.Audio(data=recon_audio, rate=16000)

In [None]:
spec = spectrogram(recon_audio)
plot_spectrogram(spec[0], title="torchaudio")

In [None]:
audio, rate = torchaudio.load('C:/Users/Yoon/Desktop/denoising1sec/Siren2_15k_200_SNR_reconsall_mono_0.wav')

In [None]:
import IPython
IPython.display.Audio(data=audio, rate=48000)

In [None]:
plot_waveform(audio,48000)

In [None]:
spec = spectrogram(audio)
plot_spectrogram(spec[0], title="torchaudio")

In [None]:
# Trained on 2-second audio chunks

In [None]:
audio, rate = torchaudio.load('C:/Users/USER/Desktop/siren_2sec/Siren2_15k_128_SNR_recons_all_mono_0.wav')

In [None]:
spec = spectrogram(audio)
plot_spectrogram(spec[0], title="torchaudio")

In [None]:
aa = [i.shape for i in Siren(in_features=1, out_features=1, hidden_features=128, hidden_layers=6, first_omega_0=15000, hidden_omega_0=200, outermost_linear=True).state_dict().values()]
def f(x):
    res = 1.
    for i in x:
        res *= i
    return res
aa = sum([f(i) for i in aa])
# 압축률: 원본 bitrate / 압축본 bitrate, BPS(Bit per second)
768 / (aa * 32 / 10 / 1000)

In [None]:
# Trained on 5-second audio chunks

In [None]:
audio, rate = torchaudio.load('C:/Users/Yoon/Desktop/denoising5sec/all_mono_0.wav')

In [None]:
plot_waveform(audio,48000)

In [None]:
spec = spectrogram(audio)
plot_spectrogram(spec[0], title="torchaudio")

In [None]:
audio, rate = torchaudio.load('C:/Users/Yoon/Desktop/denoising5sec/Siren2_15k_200_MSE_reconsall_mono_0.wav')

In [None]:
plot_waveform(audio,48000)

In [None]:
spec = spectrogram(audio)
plot_spectrogram(spec[0], title="torchaudio")

In [None]:
# Trained on 10-second audio chunks

In [None]:
audio, rate = torchaudio.load('C:/Users/Yoon/Desktop/denoising/all_mono_0.wav')

In [None]:
plot_waveform(audio,48000)

In [None]:
spec = spectrogram(audio)
plot_spectrogram(spec[0], title="torchaudio")

In [None]:
audio, rate = torchaudio.load('C:/Users/Yoon/Desktop/denoising/Siren2_15k_200_MSE_reconsall_mono_0.wav')

In [None]:
audio.shape

In [None]:
audio[:,0:240000].shape

In [None]:
plot_waveform(audio,48000)

In [None]:
spec = spectrogram(audio)
plot_spectrogram(spec[0], title="torchaudio")

In [None]:
# Merge the reconstructed chunks into a single clip

In [None]:
path = 'C:/Users/Yoon/Desktop/siren_1sec_recon'

In [None]:
wavpath = sorted(glob(os.path.join(datapath, f'Siren2_15k_200_SNR_reconsall_mono_*.wav')), key=lambda x: int(os.path.basename(x).split('.')[0].split('_')[-1]))

In [None]:
list = []
for wave in wavpath:
    audio, rate = torchaudio.load(wave)
    list.append(audio)

In [None]:
len(list)

In [None]:
pred = torch.cat(list,dim=1)

In [None]:
import IPython
IPython.display.Audio(data=pred, rate=48000)

In [None]:
plot_waveform(pred,48000)

In [None]:
spec = spectrogram(pred)
plot_spectrogram(spec[0], title="torchaudio")