In [15]:
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt

%reload_ext autoreload
%autoreload 2

In [61]:
def visualize_complex_tensor(tensor):
    """Plot the magnitude (in dB) and the phase of a complex 2-D tensor side by side.

    Args:
        tensor: Complex-valued ``torch.Tensor`` (or anything accepted by
            ``torch.as_tensor``) that ``imshow`` can render as an image.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    NOTE(review): ``torchaudio`` is not imported in the visible cells; it is
    presumably provided by another import in the notebook -- confirm.
    """
    # ``torch.Tensor`` has no ``.isinstance`` method, so the previous check
    # raised AttributeError on every call; use the builtin instead.
    if isinstance(tensor, torch.Tensor):
        # Drop the autograd graph and move to host memory so matplotlib can read it.
        tensor = tensor.detach().cpu()
    else:
        # e.g. a numpy array: convert so that ``.abs()`` / ``.angle()`` exist.
        tensor = torch.as_tensor(tensor)

    # Polar decomposition of the complex values.
    magnitude = tensor.abs()
    phase = tensor.angle()
    # NOTE(review): AmplitudeToDB is built with its default arguments while the
    # input is a magnitude -- confirm the intended dB scaling.
    amptodb = torchaudio.transforms.AmplitudeToDB()

    # One row, two panels: magnitude on the left, phase on the right.
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))

    mag_image = axs[0].imshow(amptodb(magnitude), cmap='magma')
    axs[0].set_title('Mag')

    phase_image = axs[1].imshow(phase, cmap='Blues')
    axs[1].set_title('Phase')

    # One colour bar per panel.
    fig.colorbar(mag_image, ax=axs[0])
    fig.colorbar(phase_image, ax=axs[1])

    plt.show()


def frequency_shift(stft, sr, N=513, L=512):
    """Return the frame-to-frame phase increment of a baseband-shifted STFT.

    Args:
        stft: 2-D complex array; axis 0 must broadcast against ``N`` rows and
            axis 1 is treated as the frame (time) axis.
        sr: Unused by the computation; kept for the call signature.
        N: Number of rows of the modulation grid and divisor in the exponent.
        L: Multiplier applied to the frame index in the exponent.

    Returns:
        Array with one column fewer than ``stft`` holding the difference of
        the unwrapped phase between neighbouring frames.

    NOTE(review): ``N`` is used both as the row count and as the DFT length in
    the exponent; for a one-sided STFT those usually differ -- confirm.
    """
    import numpy as np

    # Column of bin indices against a row of frame indices -> (N, frames) grid.
    bin_idx = np.arange(N).reshape(-1, 1)
    frame_idx = np.arange(stft.shape[1])

    # Complex rotation that shifts each bin to baseband.
    rotation = np.exp(-1j * (2 * np.pi / N) * bin_idx * frame_idx * L)
    demodulated = stft * rotation

    # Unwrap along time, then difference neighbouring frames.
    return np.diff(np.unwrap(np.angle(demodulated), axis=1), axis=1)




In [17]:
class convGLU(nn.Module):
    """2-D convolution followed by a gated linear unit over the channel axis.

    The convolution produces ``2 * out_channels`` feature maps; the first half
    is multiplied element-wise by the sigmoid of the second half.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of channels returned by ``forward``.
        kernel_size: ``(height, width)`` of the convolution kernel.
        padding: ``'same'`` pads by half the kernel size on each axis; any
            other value is passed to ``nn.Conv2d`` unchanged.
        batchnorm: If true, a ``BatchNorm2d`` follows the convolution.
    """

    def __init__(self, in_channels=3, out_channels=1, kernel_size=(7,7), padding='same', batchnorm=False):
        super().__init__()
        gated_channels = out_channels * 2  # value half + gate half

        if padding == 'same':
            conv_padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        else:
            conv_padding = padding

        conv = nn.Conv2d(in_channels, gated_channels, kernel_size, padding=conv_padding)
        if batchnorm:
            # Normalise value and gate channels together before gating.
            conv = nn.Sequential(conv, nn.BatchNorm2d(gated_channels))

        self.conv = conv
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        features = self.conv(x)
        half = features.shape[1] // 2

        values = features[:, :half, :, :]
        gates = features[:, half:, :, :]
        return values * self.sigmoid(gates)
    

class DNN(nn.Module):
    """Convolutional sub-network mapping a 6-channel input to a 2-channel output.

    The network is ``initial -> mid (with a skip connection) -> final``. Every
    stage is built from ``convGLU`` layers; with ``additional_conv=True`` a
    plain ``nn.Conv2d`` is inserted in front of each ``convGLU``.

    Args:
        padding: Padding of the optional 11x11 convolution in ``initial``.
            ``None`` (the default) selects half the kernel size so the spatial
            size is preserved. Only used when ``additional_conv`` is true.
        additional_conv: Insert the extra plain convolutions.
    """

    def __init__(self,padding=None,additional_conv=False):
        super().__init__()
        self._hidden_channels = 32
        hidden = self._hidden_channels

        if padding is None:
            # ``None`` is not a usable Conv2d padding; default to size-preserving
            # padding for the 11x11 kernel.
            padding = (11 // 2, 11 // 2)

        def extra_conv(in_ch, out_ch, kernel, pad):
            # Optional plain convolution; Identity keeps the Sequential indices
            # (and therefore the state-dict keys) stable when it is disabled.
            return nn.Conv2d(in_ch, out_ch, kernel, padding=pad) if additional_conv else nn.Identity()

        self.initial = nn.Sequential(
            # in_channels = 6 because the real and imaginary parts of three
            # complex spectrograms are concatenated along the channel axis.
            extra_conv(6, hidden, (11, 11), padding),
            convGLU(hidden if additional_conv else 6, hidden, (11, 11), padding='same'))

        self.mid = nn.Sequential(
            extra_conv(hidden, hidden, (7, 3), (7 // 2, 3 // 2)),
            convGLU(hidden, hidden, (7, 3), padding='same'),
            extra_conv(hidden, hidden, (7, 3), (7 // 2, 3 // 2)),
            convGLU(hidden, hidden, (7, 3), padding='same'),
        )

        self.final = nn.Sequential(
            # The optional conv must keep ``hidden`` channels: the convGLU that
            # follows expects ``hidden`` inputs (it previously emitted 1 channel,
            # which cannot be fed into that layer).
            extra_conv(hidden, hidden, (7, 3), (7 // 2, 3 // 2)),
            convGLU(hidden, hidden, (7, 3), padding='same'),
            nn.Conv2d(hidden, 2, (7, 3), padding=(7 // 2, 3 // 2)),
        )

    def forward(self, x):
        """Run the network on a 4-D input with 6 channels; returns 2 channels."""
        x = self.initial(x)
        # Skip connection around the middle stage (out-of-place add).
        x = self.mid(x) + x
        return self.final(x)


class DeepGriffinLim(nn.Module):
    """Stack of DNN blocks that iteratively refines a complex spectrogram.

    Each block performs a magnitude replacement, an iSTFT/STFT round trip and
    a learned residual correction (see ``forward``).

    Args:
        blocks: Number of ``DNN`` sub-blocks.
        n_fft: FFT size used for the STFT round trip in ``forward``.
        hop_size: Hop length used for the STFT round trip in ``forward``.
        win_size: Window length used for the STFT round trip in ``forward``.
        window: Stored on the instance only.
            NOTE(review): no window is passed to ``torch.stft``/``torch.istft``;
            applying one would change the numerical results -- confirm intent.
    """

    def __init__(self,blocks=10, n_fft=1024, hop_size=512, win_size=1024, window='hann_window'):
        super().__init__()
        # Keep the STFT configuration so ``forward`` honours the constructor
        # arguments instead of silently discarding them.
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size
        self.window = window
        self.dnn_blocks = nn.ModuleList([DNN() for _ in range(blocks)]) # DNN blocks


    def stft(self, x, n_fft=1024, hop_size=512, win_size=1024):
        """Complex STFT of ``x`` via ``torch.stft`` (no explicit window)."""
        return torch.stft(x, n_fft=n_fft, hop_length=hop_size, win_length=win_size, return_complex=True)

    def istft(self, x, n_fft=1024, hop_size=512, win_size=1024):
        """Inverse STFT of ``x`` via ``torch.istft`` (no explicit window)."""
        return torch.istft(x, n_fft=n_fft, hop_length=hop_size, win_length=win_size)


    ############################################
    # EXPERIMENTAL FUNCTIONS ###################
    ############################################
    def swap_in_mag(self,mag, x):
        """Return ``mag*cos(angle(x))`` and ``mag*sin(angle(x))`` concatenated on the last axis."""
        phase = torch.angle(x)
        real_part = mag * torch.cos(phase)
        imaginary_part = mag * torch.sin(phase)
        new_tensor = torch.cat([real_part, imaginary_part], dim=-1)  # adjust dim as per your needs
        return new_tensor

    def final_mag_swap(self, mag, x):
        """Return ``(cat[mag*cos, mag*sin], cat[mag, phase])``, both concatenated on dim 1."""
        phase = torch.angle(x)
        return torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=1), torch.cat([mag, phase], dim=1)



    def swap_in_true_mag(self, x, mag,final=False):
        """Combine ``mag`` with the phase of ``x``.

        With ``final=True`` the result is the real tensor
        ``cat[mag*cos(phase), mag*sin(phase)]`` on dim 1; otherwise it is the
        complex tensor ``mag + 1j*phase``.
        """
        phase = torch.angle(x)
        # return torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=1) # legacy
        if final:
            return torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=1)
        else:
            return  mag + 1j * phase

    ############################################
    ############################################
    ############################################


    def forward(self, x, mag, added_depth=1):
        """Refine ``x`` with every DNN block, ``added_depth`` times over.

        Args:
            x: complex64 tensor with 4 dimensions; dim 1 is squeezed before the
                iSTFT. (assumes layout (batch, 1, freq, frames) -- TODO confirm)
            mag: Magnitude tensor broadcastable against ``x``.
            added_depth: Number of passes over the whole block list.

        Returns:
            ``(layer_outputs, final, residual)``: the complex estimate after
            every block, ``swap_in_true_mag(x, mag, final=True)`` for the last
            estimate, and the complex residual predicted by the last block
            (``None`` if no block was executed).

        Raises:
            ValueError: if the block inputs are not complex64 tensors with 4
                dimensions.
        """
        layer_outputs = []
        residual = None  # stays None when there are no blocks / zero depth
        stft_cfg = dict(n_fft=self.n_fft, hop_size=self.hop_size, win_size=self.win_size)

        for _ in range(added_depth):
            for subblock in self.dnn_blocks:
                # Step one is the P_a step: keep the phase of x, impose `mag`.
                y_tilda = mag * torch.exp(1j * torch.angle(x))
                # iSTFT -> STFT round trip; restore the channel axis afterwards.
                signal = self.istft(y_tilda.squeeze(1), **stft_cfg)
                z_tilda = self.stft(signal, **stft_cfg).unsqueeze(1)

                dnn_in = self.transform_to_float([x, y_tilda, z_tilda])
                if dnn_in is None:
                    raise ValueError(
                        'DeepGriffinLim.forward expects complex64 tensors with 4 dimensions')
                dnn_out = subblock(dnn_in)

                # Channels 0 / 1 of the block output are the real / imaginary residual.
                residual = dnn_out[:, 0, ...] + 1j * dnn_out[:, 1, ...]

                x = z_tilda - residual.unsqueeze(1)

                layer_outputs.append(x)

        return layer_outputs, self.swap_in_true_mag(x, mag,final=True), residual


    @staticmethod
    def transform_to_float(tensor_list: list):
        """Stack real and imaginary parts of complex tensors along dim 1.

        Every entry must be a complex64 tensor with 4 dimensions; otherwise a
        message is printed and ``None`` is returned.
        """
        output = []
        for idx, tensor in enumerate(tensor_list):
            if tensor.dtype == torch.complex64 and tensor.dim() == 4:
                output.append(torch.cat([tensor.real, tensor.imag], dim=1))
            else:
                print(f'Input {idx} is not a complex tensor with 4 dimensions')
                return None

        return torch.cat(output, dim=1)
        



In [18]:
# DGL = DeepGriffinLim(blocks=10)

# mag = torch.randn(4,1,513,50,dtype=torch.float32)
# x = torch.randn(4,1,513,50,dtype=torch.complex64)

# tens = DGL(x,mag)

# print(tens.dtype)

# phase = torch.angle(x)

# sample = torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=1)
# torch.nn.MSELoss(reduction='mean')(sample, tens)



In [19]:
# def swap_in_true_mag(x, mag):
#     '''
#     For the short-time Fourier transform to work with batched data,
#     we need to remove the channel dimension and add it back afterwards. We
#     still need a 4D tensor for the DNN to work properly. This means
#     that the dimension needs to be removed after the swap and restored
#     again before the DNN. This entire cell is just a confirmation that
#     this is the necessary procedure.
#     '''
#     phase = torch.angle(x)
#     # return torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=1) # legacy
#     return  (mag + 1j * phase).squeeze_(1)


# x = torch.randn(16,1,513,1024, dtype=torch.complex64)
# mag = torch.randn(16,1,513,1024, dtype=torch.float32)

# out = swap_in_true_mag(x,mag)
# print(f'OUT SHAPE: {out.shape}')
# print(f'OUT TYPE: {out.dtype}')


# istft = torch.istft(out, n_fft=1024, hop_length=512, win_length=1024)
# print(f'ISTFT shape: {istft.shape}')
# print(f'ISTFT dtype: {istft.dtype}')

# stft = torch.stft(istft, n_fft=1024, hop_length=512, win_length=1024,return_complex=True)
# print(f'STFT shape: {stft.shape}')
# print(f'STFT dtype: {stft.dtype}')

# # print(f'unsqueezed stft shape: {stft.unsqueeze_(1).shape}')

In [41]:
from DGLim.data import *
from torch.utils.data import DataLoader

# Build the dataset from the hyper-parameter object `hp`.
# NOTE(review): `hp` and `AvianNatureSounds` are presumably provided by the
# star import from DGLim.data above -- confirm. `length` is fixed to 15 here
# rather than read from `hp`.
ds = AvianNatureSounds(annotation_file_path=hp.annotation_file_path,
                       root_dir=hp.root_dir,
                       key=hp.key,
                       mode=hp.mode,
                       length=15,
                       sampling_rate=hp.sampling_rate,
                       n_fft=hp.n_fft,
                       hop_length=hp.hop_length,
                       mel_spectrogram=hp.mel_spectrogram,
                       verbose=hp.verbose,
                       fixed_limit=False)


# train_data = DataLoader(ds, batch_size=hp.batch_size, shuffle=True, num_workers=hp.num_workers)
# train_data = DataLoader(ds, batch_size=8, shuffle=True, num_workers=hp.num_workers)

# batch = next(iter(train_data))

# for idx, (complex, mag, label) in enumerate(train_data):
#     print(idx, complex.shape)
#     print(idx, mag.shape)
#     print(idx, label)
#     break

In [21]:
# DGL = DeepGriffinLim(blocks=4)

# x = batch[0]
# mag = batch[1]

# print(f'x shape: {x.shape}')
# print(f'x dtype: {x.dtype}')

# print(f'mag shape: {mag.shape}')
# print(f'mag dtype: {mag.dtype}')

# DGL(x,mag)

In [42]:
# Single-item loader used to pull one example for inspection.
train_data = DataLoader(ds, batch_size=1, shuffle=True, num_workers=hp.num_workers)

# shuffle=True, so a different example is drawn on every run of this cell.
batch = next(iter(train_data))

# Each batch unpacks into three items, named here: complex spectrogram, magnitude, label.
comp, mag, label = batch

In [62]:
# Simple Training loop
from tqdm import tqdm
# Fresh 10-block model; re-running this cell discards any trained weights.
model = DeepGriffinLim(blocks=10)
# Rebinds `train_data` (previously batch_size=1) to a loader with batch_size=8.
train_data = DataLoader(ds, batch_size=8, shuffle=True, num_workers=hp.num_workers)

def train(model, data_loader, device,epochs=10):
    """Simple training loop that repeatedly fits the first batch of ``data_loader``.

    One batch is drawn once; in every epoch fresh complex Gaussian noise is
    added to it, the model is run on the noisy input, and one Adam step is
    taken on the L1 distance between the last layer output and the clean
    spectrogram.

    Args:
        model: Module called as ``model(x=..., mag=...)`` and returning
            ``(layer_outputs, final_out, residual)``.
        data_loader: Iterable yielding ``(complex_spectrogram, magnitude, label)``.
        device: Device the model and tensors are moved to.
        epochs: Number of optimisation steps on the drawn batch.

    Returns:
        List with the loss value of every epoch.
    """
    model.train()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    # 'None' is not a valid reduction; 'mean' yields the scalar that
    # ``loss.backward()`` / ``loss.item()`` below require.
    criterion = torch.nn.L1Loss(reduction='mean')

    # Draw from the loader passed in, not from a notebook-level global.
    clean, mag, _label = next(iter(data_loader))
    clean = clean.to(device)
    mag = mag.to(device)

    losses = []
    for epoch in range(epochs):
        # Initialise the noisy signal: new noise realisation every epoch.
        noisy = clean + torch.randn_like(clean, dtype=torch.complex64)

        optimizer.zero_grad()
        layer_outputs, final_out, residual = model(x=noisy, mag=mag)

        # The last layer output is (round-trip spectrogram - predicted residual),
        # so its L1 distance to the clean spectrogram equals the L1 distance
        # between the predicted residual and the true one.
        loss = criterion(layer_outputs[-1], clean)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        print(f'Loss: {loss.item()}')

    return losses



# train(model, train_data, 'cpu',epochs=10)

In [24]:
# Sanity check: nn.L1Loss accepts two complex64 tensors and returns a scalar.
x = torch.randn(4,1,513,200,dtype=torch.complex64)

y = torch.randn(4,1,513,200,dtype=torch.complex64)

nn.L1Loss()(x,y)

tensor(1.2526)

In [25]:
# Draw one batch from whichever `train_data` loader is currently bound.
batch = next(iter(train_data))

In [26]:
# Shape and dtype of the first channel of the first complex spectrogram in the batch.
print(batch[0][0][0].shape)
print(batch[0][0][0].dtype)

torch.Size([513, 282])
torch.complex64


In [28]:
# model = DeepGriffinLim(blocks=10)

# comp, mag, label = batch

# reconstruct, dnn_out, residual = model(comp,mag)

In [None]:
# NOTE(review): `dnn_out` is only assigned in the commented-out cell above, so
# this cell raises NameError on a fresh kernel.
dnn_out[0][0].shape

In [86]:
import torch
import torchvision
import matplotlib.pyplot as plt

# Draw a new batch for the visualisation helpers defined below.
batch = next(iter(train_data))

def frequency_shift(stft, sr, N=513, L=512):
    """Phase difference between neighbouring frames after a baseband shift.

    ``stft`` is multiplied by ``exp(-1j * (2*pi/N) * k * l * L)`` where ``k``
    runs over ``N`` rows and ``l`` over the columns of ``stft``; the phase of
    the product is unwrapped along axis 1 and differenced along axis 1.

    ``sr`` is accepted but not used.

    NOTE(review): this redefines a function with the same name and body from
    an earlier cell; the later definition wins once this cell has run.
    """
    import numpy as np

    n_frames = stft.shape[1]
    k = np.arange(N)[:, np.newaxis]   # bin indices as a column vector
    l = np.arange(n_frames)           # frame indices

    step = 2 * np.pi / N
    carrier = np.exp(-1j * step * k * l * L)

    # Shift to baseband, then take the phase.
    baseband = stft * carrier
    phase = np.angle(baseband)

    # Unwrap along the time axis before differencing so 2*pi jumps vanish.
    unwrapped = np.unwrap(phase, axis=1)
    return np.diff(unwrapped, axis=1)





def visualize_complex_tensor(tensor, phase_only=False):
    """Show a complex tensor as magnitude (dB) and phase images.

    Args:
        tensor: Complex ``torch.Tensor`` that ``imshow`` can render.
        phase_only: If true, draw only the phase panel.

    Returns:
        ``phase`` when ``phase_only`` is true, otherwise ``(magnitude, phase)``,
        where the values come from ``tensor.abs()`` and ``tensor.angle()``.

    NOTE(review): ``torchaudio`` is not imported in the visible cells; it is
    presumably provided by another import in the notebook -- confirm.
    """
    tensor = tensor.detach()

    # Polar decomposition of the complex values.
    magnitude, phase = tensor.abs(), tensor.angle()
    to_db = torchaudio.transforms.AmplitudeToDB()

    if phase_only:
        # Single panel with the phase only.
        fig, ax = plt.subplots(1, 1, figsize=(20, 6))
        phase_image = ax.imshow(phase, cmap='magma')
        ax.set_title('Phase')
        fig.colorbar(phase_image, ax=ax)
        plt.show()
        return phase

    # Two stacked panels: magnitude on top, phase below.
    fig, (mag_ax, phase_ax) = plt.subplots(2, 1, figsize=(20, 6))

    mag_image = mag_ax.imshow(to_db(magnitude), cmap='magma')
    mag_ax.set_title('Mag')

    phase_image = phase_ax.imshow(phase, cmap='Blues')
    phase_ax.set_title('Phase')

    # One colour bar per panel.
    fig.colorbar(mag_image, ax=mag_ax)
    fig.colorbar(phase_image, ax=phase_ax)

    plt.show()
    return magnitude, phase



# visualize_complex_tensor(reconstruct[0][0])

In [92]:
# Complex Gaussian noise shaped like the first channel of the first spectrogram in `batch`.
noise = torch.randn_like(batch[0][0][0], dtype=torch.complex64)

# real , imag = visualize_complex_tensor(batch[0][0][0],phase_only=False)
# imag2 = visualize_complex_tensor(batch[0][0][0],phase_only=True)


# print(imag2.shape)


# phase_diff = frequency_shift(imag, 48000)

# sr = 48000
# plt.figure(figsize=(15,6))
# librosa.display.specshow(phase_diff,
#                          sr=sr,
#                          x_axis='time',
#                          y_axis='linear')
# plt.colorbar(format="%+2.f")
# plt.show()
