In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src.utils.parse_config import ConfigParser
import json

with open('src/configs/one_batch.json', 'r') as fd:
    config_json = json.load(fd)

config_parser = ConfigParser(config=config_json)



In [3]:
from src.utils.object_loading import get_dataloaders

dataloaders = get_dataloaders(config_parser)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
batch = next(iter(dataloaders['train']))

In [5]:
batch['spectrogram'].size()

torch.Size([3, 80, 255])

In [6]:
batch['wav_real'].size()

torch.Size([3, 1, 8191])

In [9]:
256 / (2 ** 2)

64.0

In [7]:
from src.model.generator import Generator

g = Generator(
    in_features=80,
    k_u=[16, 16, 4, 4],
    h_u=512, 
    k_r=[3,7,11],
    D_r=[
        [[1,1], [3,1], [5,1]],
        [[1,1], [3,1], [5,1]],
        [[1,1], [3,1], [5,1]]
    ]
)



In [8]:
print(g)

Generator(
  (preconv): Conv1d(80, 512, kernel_size=(7,), stride=(1,), padding=(3,))
  (conv_transpose): ModuleList(
    (0): ConvTranspose1d(512, 256, kernel_size=(16,), stride=(8,), padding=(4,))
    (1): ConvTranspose1d(256, 128, kernel_size=(16,), stride=(8,), padding=(4,))
    (2): ConvTranspose1d(128, 64, kernel_size=(4,), stride=(2,), padding=(1,))
    (3): ConvTranspose1d(64, 32, kernel_size=(4,), stride=(2,), padding=(1,))
  )
  (mrf): ModuleList(
    (0): MRF(
      (res_blocks): ModuleList(
        (0): ResBlock(
          (blocks): ModuleList(
            (0): InnerResBlock(
              (blocks): ModuleList(
                (0-1): 2 x Sequential(
                  (0): LeakyReLU(negative_slope=0.01)
                  (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=same)
                )
              )
            )
            (1): InnerResBlock(
              (blocks): ModuleList(
                (0): Sequential(
                  (0): LeakyReLU(negative_s

In [9]:
gen_output = g(batch['spectrogram'])

In [10]:
gen_output.size()

torch.Size([3, 1, 65280])

In [11]:
gen_output = gen_output[:, :, :-1]

In [12]:
gen_output.size()

torch.Size([3, 1, 65279])

: 

In [None]:
from src.model.discriminator import MPD, MSD

mpd = MPD()
msd = MSD()

In [None]:
batch['wav_real'].size()

torch.Size([3, 1, 8191])

In [None]:
mpd_real_out, mpd_real_fmaps = mpd(batch['wav_real'])
print('===')
for o in mpd_real_out:
    print(o.size())

Before: torch.Size([3, 1, 8191])
After: torch.Size([3, 990])
Before: torch.Size([3, 1, 8191])
After: torch.Size([3, 644])
Before: torch.Size([3, 1, 8191])
After: torch.Size([3, 375])
Before: torch.Size([3, 1, 8191])
After: torch.Size([3, 243])
Before: torch.Size([3, 1, 8191])
After: torch.Size([3, 124])
===
torch.Size([3, 990])
torch.Size([3, 644])
torch.Size([3, 375])
torch.Size([3, 243])
torch.Size([3, 124])


In [None]:
mpd_gen_out, mpd_gen_fmaps = mpd(gen_output)
print('===')
for o in mpd_gen_out:
    print(o.size())

Before: torch.Size([3, 1, 8191])
After: torch.Size([3, 990])
Before: torch.Size([3, 1, 8191])
After: torch.Size([3, 644])
Before: torch.Size([3, 1, 8191])
After: torch.Size([3, 375])
Before: torch.Size([3, 1, 8191])
After: torch.Size([3, 243])
Before: torch.Size([3, 1, 8191])
After: torch.Size([3, 124])
===
torch.Size([3, 990])
torch.Size([3, 644])
torch.Size([3, 375])
torch.Size([3, 243])
torch.Size([3, 124])


In [None]:
msd_real_out, msd_real_fmaps = msd(batch['wav_real'])
print('===')
for o in msd_real_out:
    print(o.size())

===
torch.Size([3, 128])
torch.Size([3, 64])
torch.Size([3, 33])


In [None]:
msd_gen_out, msd_gen_fmaps = msd(gen_output)
print('===')
for o in msd_gen_out:
    print(o.size())

===
torch.Size([3, 128])
torch.Size([3, 64])
torch.Size([3, 33])


In [None]:
from src.loss.gan_loss import GLoss, DLoss

In [None]:
generator_loss = GLoss(
    mel_config={
        "sample_rate": 22050,
        "win_length": 1024,
        "hop_length": 256,
        "n_fft": 1024,
        "f_min": 0,
        "f_max": 8000,
        "n_mels": 80,
        "power": 1.0
    },
    lambda_mel=0.5,
    lambda_fm=0.5
)
generator_loss(
    wav_generated=gen_output, wav_real=batch['wav_real'], 
    mpd_features_generated=mpd_gen_fmaps, mpd_features_real=mpd_real_fmaps,
    msd_features_generated=msd_gen_fmaps, msd_features_real=msd_real_fmaps,
    mpd_d_out_generated=mpd_gen_out, mpd_d_out_real=mpd_real_out,
    msd_d_out_generated=msd_gen_out, msd_d_out_real=msd_real_out
)

tensor(28.4476, grad_fn=<AddBackward0>)

In [None]:
discriminator_loss = DLoss()
discriminator_loss(
    msd_d_out_generated=msd_gen_out, msd_d_out_real=msd_real_out,
    mpd_d_out_generated=mpd_gen_out, mpd_d_out_real=mpd_real_out
)

tensor(8.0743, grad_fn=<AddBackward0>)