In [50]:
import torch
import torch.nn as nn

# from .istft import InverseSTFT


class ConvGLU(nn.Module):
    """2-D convolution followed by a gated linear unit over the channel axis.

    The convolution emits ``2 * out_ch`` channels. The first half is the
    signal; the second half, squashed by a sigmoid, is the gate that
    multiplies it element-wise.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of channels returned after gating.
        kernel_size: ``(height, width)`` of the convolution kernel.
        padding: Padding handed to ``nn.Conv2d``. ``None`` (the default)
            selects ``kernel_size // 2`` on each axis, which preserves the
            spatial size for odd kernels. Any explicit value, including
            ``0``, is used as given.
        batchnorm: If true, ``nn.BatchNorm2d`` is applied to the
            ``2 * out_ch`` convolution outputs before gating.
    """

    def __init__(self, in_ch, out_ch, kernel_size=(7, 7), padding=None, batchnorm=False):
        super().__init__()
        # Compare against None instead of relying on truthiness: `padding=0`
        # is a valid request and must not be replaced by the default.
        if padding is None:
            padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(in_ch, out_ch * 2, kernel_size, padding=padding)
        if batchnorm:
            self.conv = nn.Sequential(
                self.conv,
                nn.BatchNorm2d(out_ch * 2)
            )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``signal * sigmoid(gate)`` where both halves come from the conv."""
        x = self.conv(x)
        half = x.shape[1] // 2
        return x[:, :half, ...] * self.sigmoid(x[:, half:, ...])


class DeGLI_DNN(nn.Module):
    """Residual-estimating CNN used by one DeGLI block.

    The three inputs are concatenated along the channel axis (6 channels
    in total are expected by the first layer) and pushed through gated
    convolutions with one skip connection; a final plain convolution
    reduces the result to 2 channels.
    """

    def __init__(self):
        super().__init__()
        hidden = 16
        self.convglu_first = ConvGLU(6, hidden, kernel_size=(11, 11), batchnorm=True)
        self.two_convglus = nn.Sequential(
            ConvGLU(hidden, hidden, batchnorm=True),
            ConvGLU(hidden, hidden),
        )
        self.convglu_last = ConvGLU(hidden, hidden)
        self.conv = nn.Conv2d(hidden, 2, kernel_size=(7, 7), padding=(3, 3))

    def forward(self, x, mag_replaced, consistent):
        """Return the 2-channel output for the channel-wise stacked inputs."""
        stacked = torch.cat((x, mag_replaced, consistent), dim=1)
        skip = self.convglu_first(stacked)
        # Skip connection around the two middle gated convolutions.
        merged = self.two_convglus(skip) + skip
        return self.conv(self.convglu_last(merged))


def replace_magnitude(x, mag):
    """Keep the phase of ``x`` and impose the magnitude ``mag``.

    The result is real-valued: ``mag * cos(phase)`` followed by
    ``mag * sin(phase)``, concatenated along dim 1 (real part first,
    imaginary part second).
    """
    theta = torch.angle(x)
    real_part = mag * torch.cos(theta)
    imag_part = mag * torch.sin(theta)
    return torch.cat((real_part, imag_part), dim=1)


class DeGLI(nn.Module):
    """Iterative phase refinement: magnitude projection, STFT-consistency
    projection, then a learned residual correction, repeated per DNN block.

    Args:
        n_fft: FFT size used by both ``stft`` and ``istft``; also the length
            of the Hann window.
        hop_length: Hop size used by both ``stft`` and ``istft``.
        depth: Number of ``DeGLI_DNN`` blocks applied in sequence per repeat.
            Must be at least 1.
        out_all_block: If true, ``forward`` returns the estimate after every
            repeat; otherwise only the final estimate.

    Raises:
        ValueError: If ``depth`` is smaller than 1.
    """

    def __init__(self, n_fft: int, hop_length: int, depth=3, out_all_block=True):
        super().__init__()
        # With no DNN block `forward` could never produce a residual.
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.out_all_block = out_all_block

        # Frozen parameter: moves with the module across devices and is saved
        # in the state dict under "window", but is never trained.
        self.window = nn.Parameter(torch.hann_window(n_fft), requires_grad=False)
        self.dnns = nn.ModuleList([DeGLI_DNN() for _ in range(depth)])

    def stft(self, x):
        """Complex STFT of waveform(s) ``x`` using the module's window."""
        return torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            return_complex=True,
        )

    def istft(self, x, length=None):
        """Inverse STFT of complex spectrogram(s) ``x``.

        Args:
            x: Complex spectrogram accepted by ``torch.istft``.
            length: Optional number of output samples, forwarded to
                ``torch.istft``. ``None`` keeps the full reconstructed signal.
        """
        return torch.istft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            length=length,
        )

    def forward(self, x, mag, max_length=None, repeat=1):
        """Refine the phase of ``x`` under the magnitude constraint ``mag``.

        Args:
            x: Complex spectrogram estimate with a singleton channel at dim 1
                (the code reads ``x.real`` / ``x.imag`` and concatenates them
                on dim 1 into the 2 channels the DNN expects).
            mag: Target magnitude, broadcastable against ``x``.
            max_length: Optional waveform length (int or one-element tensor)
                passed to the inverse STFT as ``length``. ``None`` keeps the
                full reconstructed length.
            repeat: How many times the whole stack of DNN blocks is applied.
                Must be at least 1.

        Returns:
            Tuple ``(out_repeats, final_out, residual)``:
            ``out_repeats`` is the complex estimate(s) with a new dim 1 (one
            entry per repeat when ``out_all_block`` is true, otherwise a
            single entry holding the final estimate); ``final_out`` is
            ``replace_magnitude`` applied to the final estimate; ``residual``
            is the output of the last DNN block that ran.

        Raises:
            ValueError: If ``repeat`` is smaller than 1.
        """
        # Without at least one pass `residual` would be undefined at return.
        if repeat < 1:
            raise ValueError(f"repeat must be >= 1, got {repeat}")
        if isinstance(max_length, torch.Tensor):
            max_length = int(max_length.item())

        out_repeats = []
        for _ in range(repeat):
            for dnn in self.dnns:
                # Projection 1: impose the target magnitude. Real tensor with
                # real/imag stacked on dim 1.
                mag_replaced = replace_magnitude(x, mag)
                complex_mag_replaced = torch.complex(mag_replaced[:, :1], mag_replaced[:, 1:])

                # Projection 2: go to the time domain and back so the
                # spectrogram corresponds to an actual waveform.
                waves = self.istft(complex_mag_replaced.squeeze(1), length=max_length)
                consistent = self.stft(waves).unsqueeze(1)
                consistent = torch.cat([consistent.real, consistent.imag], dim=1)

                # Real/imag view of the current estimate for the DNN; kept in
                # its own name so `x` stays complex throughout.
                x_real_imag = torch.cat([x.real, x.imag], dim=1)

                residual = dnn(x_real_imag, mag_replaced, consistent)
                refined = consistent - residual

                x = torch.complex(refined[:, :1], refined[:, 1:])

            if self.out_all_block:
                out_repeats.append(x)

        if self.out_all_block:
            out_repeats = torch.stack(out_repeats, dim=1)
        else:
            out_repeats = x.unsqueeze(1)

        final_out = replace_magnitude(x, mag)

        return out_repeats, final_out, residual


In [54]:
from DGLim.data import *

# minimal train loop
# hyperparameters
# NOTE(review): `DataLoader` and `ds` are not defined in the visible cells;
# presumably they are provided by the `from DGLim.data import *` star import
# above. Confirm, and prefer explicit imports.

lr = 3e-4
epochs = 10
batch_size = 4
num_workers = 4  # NOTE(review): defined but never passed to the DataLoader below
weight_decay = 0.0001

# out_all_block=False: the model returns only the final estimate in `out_repeats`.
model = DeGLI(n_fft=1024,hop_length=512,out_all_block=False)
critereon = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

train_data = DataLoader(ds, batch_size=batch_size, shuffle=True)

# train loop
model.train()

# One batch is drawn here, outside the loop; every iteration below reuses it.
batch = next(iter(train_data))

for epoch in range(epochs):
    comp, mag, _ = batch
    out_repeats, final_out, residual = model(x=comp, mag=mag)

    # `signal` stacks the magnitude and the phase angle along dim 1, while
    # `final_out` is built by replace_magnitude as (mag*cos, mag*sin), i.e.
    # real/imaginary parts.
    # NOTE(review): the subtraction below therefore mixes two different
    # representations. Confirm that this loss is what was intended.
    signal = torch.cat([mag,torch.angle(comp)], dim=1)
    loss = critereon(final_out - signal, residual)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"loss item: {loss}")
    # Leaves the loop after the first optimisation step, so `epochs` has no
    # effect as written.
    break

loss item: 15.593008041381836


torch.Size([513, 469])

In [56]:
# Element-wise exact equality between the first channel of the first item of
# the target and of the model output (displayed as the cell result).
torch.eq(signal[0, 0], final_out[0, 0])

tensor([[False, False, False,  ..., False, False,  True],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False,  True, False,  ..., False, False, False]])

In [None]:
# Scratch check: the difference of two identical all-ones tensors
# (displayed as the cell result).
x = torch.ones(1, 2, 4, 4)
x2 = torch.ones_like(x)

torch.sub(x, x2)

In [None]:
# Import the submodule explicitly: a bare `import librosa` does not guarantee
# that `librosa.display` is loaded on every librosa version.
import librosa.display

# NOTE(review): `torchaudio` and `plt` are not imported in the visible cells;
# presumably they come from an earlier cell or a star import. Confirm.

# The input below is a magnitude (torch.abs of a complex spectrogram), so the
# converter must use the magnitude scale; the default stype is "power".
db_convert = torchaudio.transforms.AmplitudeToDB(stype="magnitude")

ab = db_convert(torch.abs(comp))
ph = torch.angle(comp)

print(ab.shape)

# Show the first channel of the first item of the model output.
# .cpu() makes the conversion to NumPy work when the tensor is not on the CPU.
librosa.display.specshow(final_out[0][0].detach().cpu().numpy(),
                         sr=48000,
                         x_axis='time',
                         y_axis='linear',
                         )

plt.colorbar(format="%+2.f")
plt.show()

In [None]:
# Plot the second channel of the second batch item of `signal`.
# NOTE(review): `visualize_tensor` is not defined in the visible cells, so the
# meaning of `convert=False` should be checked against its definition.
visualize_tensor(signal[1, 1], convert=False)