In [4]:
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
from DGLim.data import *
import librosa
import matplotlib.pyplot as plt

%reload_ext autoreload
%autoreload 2

In [5]:
class convGLU(nn.Module):
    """2-D convolution followed by a gated linear unit (GLU).

    A single convolution produces ``2 * out_channels`` feature maps. The first
    half is the linear path, the second half is squashed by a sigmoid and used
    as a multiplicative gate, so the block returns ``out_channels`` maps.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels returned by the block.
        kernel_size: Either an int (square kernel) or a ``(height, width)`` pair.
        padding: ``'same'`` pads each spatial axis by ``kernel // 2`` (this keeps
            the spatial size for odd kernels). Any other value is forwarded to
            ``nn.Conv2d`` unchanged.
        batchnorm: When True, ``nn.BatchNorm2d`` is applied to the convolution
            output before the gate is computed.
    """

    def __init__(self, in_channels=3, out_channels=1, kernel_size=(7,7), padding='same', batchnorm=False):
        super().__init__()
        # nn.Conv2d accepts a bare int, so accept it here too instead of failing
        # on kernel_size[0] below.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if padding == 'same':
            padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        # Twice the requested channels: one half is the signal, the other the gate.
        self.conv = nn.Conv2d(in_channels, out_channels * 2, kernel_size, padding=padding)
        self.sigmoid = nn.Sigmoid()
        if batchnorm:
            # Keep the attribute name `conv` so forward() is agnostic of the option.
            self.conv = nn.Sequential(
                self.conv,
                nn.BatchNorm2d(out_channels * 2),
            )

    def forward(self, x):
        """Return ``linear * sigmoid(gate)`` for the two halves of ``conv(x)``."""
        x = self.conv(x)
        half = x.shape[1] // 2  # channel axis is dim 1; it holds 2 * out_channels maps

        linear = x[:, :half, :, :]
        gate = x[:, half:, :, :]
        return linear * self.sigmoid(gate)
    

class DNN(nn.Module):
    """Stack of gated convolutions with one residual connection.

    The network takes a 7-channel real tensor and returns a 2-channel real
    tensor of the same spatial size. In this file ``DGL_block.forward`` builds
    the input by concatenating the real and imaginary parts of three complex
    spectrograms (6 channels) with ``mag`` and reads the two output channels as
    the real and imaginary part of a complex residual.
    NOTE(review): that adds up to 7 only if ``mag`` has a single channel - confirm.

    Args:
        padding: Padding of the optional 11x11 convolution in front of the
            first gated layer. ``None`` (default) means ``(5, 5)`` so that the
            spatial size is preserved. Only used when ``additional_conv`` is True.
        additional_conv: When True, a plain ``nn.Conv2d`` is inserted in front of
            every gated layer; otherwise ``nn.Identity`` placeholders are used,
            so the indices inside each ``nn.Sequential`` do not depend on the flag.
    """

    def __init__(self,padding=None,additional_conv=False):
        super().__init__()
        self._hidden_channels = 32
        hidden = self._hidden_channels

        # padding=None used to be handed straight to nn.Conv2d, which cannot run
        # with it; fall back to size-preserving padding for the 11x11 kernel.
        first_padding = (11 // 2, 11 // 2) if padding is None else padding
        narrow_padding = (7 // 2, 3 // 2)  # size-preserving padding for the (7, 3) kernels

        self.initial = nn.Sequential(
            nn.Conv2d(7, hidden, (11,11), padding=first_padding) if additional_conv else nn.Identity(),
            convGLU(hidden if additional_conv else 7, hidden, (11,11), padding='same'))

        self.mid = nn.Sequential(
            nn.Conv2d(hidden, hidden, (7,3), padding=narrow_padding) if additional_conv else nn.Identity(),
            convGLU(hidden, hidden, (7,3), padding='same'),
            nn.Conv2d(hidden, hidden, (7,3), padding=narrow_padding) if additional_conv else nn.Identity(),
            convGLU(hidden, hidden, (7,3), padding='same'),
        )

        self.final = nn.Sequential(
            # Must emit `hidden` channels: the gated layer right after it expects
            # `hidden` inputs (the previous value of 1 made this path unusable).
            nn.Conv2d(hidden, hidden, (7,3), padding=narrow_padding) if additional_conv else nn.Identity(),
            convGLU(hidden, hidden, (7,3), padding='same'),
            nn.Conv2d(hidden, 2, (7,3), padding=narrow_padding),
        )

    def forward(self, x):
        """Run ``initial -> mid (+ skip connection) -> final`` and return the result."""
        x = self.initial(x)
        residual = x
        x = self.mid(x)
        # Out-of-place add: avoids mutating a tensor autograd may still reference.
        x = x + residual
        x = self.final(x)
        return x


class DGL_block(nn.Module):
    """One spectrogram-refinement step: two projections plus a learned residual.

    ``forward`` performs, in order:

    1. magnitude projection (``magswap``): keep the phase of the current
       estimate, impose the given magnitude;
    2. an iSTFT -> STFT round trip of that projection;
    3. a ``DNN`` that sees the three complex spectrograms (as real/imag
       channels) plus the magnitude and predicts a complex residual;
    4. subtraction of the residual from the round-trip result, followed by a
       final magnitude projection.

    Args:
        blocks: Accepted for interface compatibility; not used by this class.
        n_fft: FFT size used by ``forward`` for the iSTFT/STFT round trip.
        hop_size: Hop length used by ``forward``.
        win_size: Window length used by ``forward``.
        window: Accepted for interface compatibility; not used.
            NOTE(review): ``stft``/``istft`` pass no window to torch, i.e. they
            do not apply a Hann window despite this default - confirm whether
            that matches how the input spectrograms were computed.
    """

    def __init__(self,blocks=10, n_fft=1024, hop_size=512, win_size=1024, window='hann_window'):
        super().__init__()
        # Remember the STFT geometry so forward() honours the constructor
        # arguments instead of silently falling back to the method defaults.
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size
        self.dnn = DNN()

    def stft(self, x, n_fft=1024, hop_size=512, win_size=1024):
        """Complex STFT of ``x`` via ``torch.stft`` (no window argument is passed)."""
        return torch.stft(x, n_fft=n_fft, hop_length=hop_size, win_length=win_size, return_complex=True)

    def istft(self, x, n_fft=1024, hop_size=512, win_size=1024):
        """Inverse STFT of ``x`` via ``torch.istft`` (no window argument is passed)."""
        return torch.istft(x, n_fft=n_fft, hop_length=hop_size, win_length=win_size)

    def magswap(self, mag, x_tilda, eps=1e-8):
        """Return ``mag * x_tilda / |x_tilda|``: phase of ``x_tilda``, magnitude ``mag``.

        Args:
            mag: Target magnitude, broadcastable against ``x_tilda``.
            x_tilda: Complex spectrogram that supplies the phase.
            eps: Lower bound for ``|x_tilda|`` in the division. Bins that are
                exactly zero yield 0 instead of NaN; bins with modulus >= eps
                are unaffected.
        """
        return mag * x_tilda / torch.clamp(torch.abs(x_tilda), min=eps)

    def forward(self,x_tilda, mag, added_depth=1):
        """Refine the complex spectrogram estimate ``x_tilda``.

        Args:
            x_tilda: complex64 tensor with 4 dimensions; dim 1 is squeezed
                before the iSTFT. NOTE(review): presumably
                (batch, 1, freq, frames) - confirm against the data pipeline.
            mag: Target magnitude, concatenated to the DNN input on dim 1.
            added_depth: Not used.

        Returns:
            Tuple ``(z_tilda, residual, final, subblock_out)``, each with the
            channel axis restored at dim 1:
            ``z_tilda`` is the iSTFT->STFT round trip of the magnitude-projected
            input, ``residual`` the complex DNN output, ``subblock_out`` is
            ``z_tilda - residual`` and ``final`` its magnitude projection.
            NOTE(review): ``subblock_out`` was an undefined name before; binding
            it to the pre-projection estimate is an assumption - confirm.

        Raises:
            TypeError: if any of the three spectrograms is not a 4-D complex64
                tensor.
        """
        stft_args = dict(n_fft=self.n_fft, hop_size=self.hop_size, win_size=self.win_size)

        # Step one is the P_a step: impose the target magnitude.
        y_tilda = self.magswap(mag=mag, x_tilda=x_tilda)

        # Round trip through the time domain.
        z_tilda = self.stft(self.istft(y_tilda.squeeze(1), **stft_args), **stft_args)

        dnn_in = self.transform_to_float([x_tilda, y_tilda, z_tilda.unsqueeze(1)])
        if dnn_in is None:
            # transform_to_float already printed which input was rejected.
            raise TypeError('forward expects complex64 spectrograms with 4 dimensions')
        dnn_in = torch.cat([dnn_in, mag], dim=1)

        dnn_out = self.dnn(dnn_in)
        # Output channel 0 is the real part, channel 1 the imaginary part.
        residual = torch.complex(dnn_out[:, 0, ...], dnn_out[:, 1, ...])

        # Out-of-place unsqueeze: never reshape tensors that are part of the
        # autograd graph in place.
        subblock_out = (z_tilda - residual).unsqueeze(1)

        final = self.magswap(mag=mag, x_tilda=subblock_out)
        return z_tilda.unsqueeze(1), residual.unsqueeze(1), final, subblock_out

    @staticmethod
    def transform_to_float(tensor_list: list):
        """Stack real and imaginary parts of complex tensors along dim 1.

        For inputs ``[a, b, ...]`` the channel order of the result is
        ``a.real, a.imag, b.real, b.imag, ...``.

        Returns:
            The concatenated real tensor, or ``None`` (after printing a message)
            as soon as an input is not a complex64 tensor with 4 dimensions.
        """
        output = []
        for idx, tensor in enumerate(tensor_list):
            if tensor.dtype != torch.complex64 or tensor.dim() != 4:
                print(f'Input {idx} is not a complex tensor with 4 dimensions')
                return None
            output.append(torch.cat([tensor.real, tensor.imag], dim=1))

        return torch.cat(output, dim=1)
        



In [9]:
# Training setup: data loader, model, optimiser and loss, followed by a single
# batch pulled from the loader for experimentation.

# Optimisation hyperparameters
epochs = 3
learning_rate, weight_decay = 5e-5, 1e-4

# NOTE(review): `DataLoader`, `ds` and `hp` are not defined in the cells visible
# here - presumably they come from `from DGLim.data import *` or an earlier
# cell; confirm before relying on Restart & Run All.
train_data = DataLoader(
    ds,
    batch_size=hp.batch_size,
    shuffle=True,
    num_workers=hp.num_workers,
)

# Model in training mode (nn.Module.train() returns the module itself)
model = DGL_block().train()

# Adam with L2 weight decay; the loss sums absolute errors over all elements
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
criterion = nn.L1Loss(reduction='sum')

# First (shuffled) batch from the loader
batch = next(iter(train_data))

# for epoch in range(epochs):
#     clear, noisy, mag, label = batch
#     # Forward pass
    
#     z_tilda, residual, final, subblock_out = model(x_tilda=noisy, mag=mag)

#     z_tilda_for_loss = convert_from_complex(z_tilda)
#     residual_for_loss = convert_from_complex(residual)
#     clear_for_loss = convert_from_complex(clear)



#     print(z_tilda_for_loss[:,1:,...].shape)

#     # Compute loss
#     loss = criterion(z_tilda_for_loss[:,1:,...] - clear_for_loss[:,1:,...], residual_for_loss[:,1:,...])

#     # Backward pass
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

#     # Print loss
#     print(f'Epoch: {epoch} Loss: {loss.item()}')

#     phase_2 = torch.angle(final)

#     librosa.display.specshow((phase_2[0][0]).detach().numpy(),
#         sr=48000,
#         x_axis='time',
#         y_axis='linear',
#         # vmin=0,
#         # vmax=44100//2
#         )
#     plt.colorbar(format="%+2.f")
#     plt.show() 

    




NameError: name 'model' is not defined