In [105]:
import os
import json
import wandb
import itertools
import numpy as np
from time import gmtime, strftime
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torch
import torchaudio
import torch.nn.functional as F

import sys
sys.path.append('../training')
from utils import utils, data_utils, audio_utils
from hyface import Nansy, BShall_Ecapa, BShall_Nimbre

from networks.bshall import AcousticModel
from networks.discriminator import Discriminator
from datasets.loader import Dataset

import IPython.display as ipd

In [111]:
# Load the experiment hyper-parameters from the JSON config and wrap them in
# an HParams object. The file is parsed straight from the handle so that no
# throw-away `data` variable is left behind (that name is used for the batch
# dict further down).
config_path = '../training/configs/bshall_pretrained_ecapa.json'
with open(config_path, "r") as f:
    config = json.load(f)
args = utils.HParams(**config)

## Data

In [1]:
# Training split of VCTK, read from the file lists under ../training/filelists
# at the sample rate given by the config.
trainset = Dataset(
    args,
    meta_root='../training/filelists',
    mode='train',
    datasets=['vctk'],
    sample_rate=args.data.sample_rate,
)

NameError: name 'Dataset' is not defined

In [113]:
# Fetch the first training item and inspect it. The recorded output shows a
# (64000,) audio array and a [200, 256] `hubert` tensor.
audio, hubert = trainset[0]
audio.shape, hubert.shape

((64000,), torch.Size([200, 256]))

In [114]:
# Batch the dataset four items at a time, using the dataset's own collate function.
trainloader = DataLoader(trainset, batch_size=4, collate_fn=trainset.collate)

In [115]:
# Pull a single batch from the loader for inspection; later cells index it by
# the keys 'audio', 'hubert' and 'frame_lengths'.
data = next(iter(trainloader))

In [116]:
# Shapes of the collated batch fields (recorded output: (4, 64000),
# [4, 256, 200] and (4,)).
data['audio'].shape, data['hubert'].shape, data['frame_lengths'].shape

((4, 64000), torch.Size([4, 256, 200]), (4,))

In [117]:
# Play back the first clip of the batch. The playback rate comes from the same
# config value the dataset was built with, rather than a hard-coded 16000, so
# the two cannot drift apart.
ipd.Audio(data['audio'][0], rate=args.data.sample_rate)

# Model

## ECAPA-TDNN test

In [41]:
# Load the pretrained SpeechBrain encoder identified by
# "speechbrain/spkrec-ecapa-voxceleb".
from speechbrain.pretrained import EncoderClassifier
classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb")

In [42]:
# Convert the batch audio to a tensor and encode it with the classifier.
# Recorded output: a [4, 64000] signal gives [4, 1, 192] embeddings.
signal = torch.tensor(data['audio'])
embeddings = classifier.encode_batch(signal)
print(signal.shape, embeddings.shape)

torch.Size([4, 64000]) torch.Size([4, 1, 192])


In [43]:
# Shape check only: tile the embedding 100 times along dim 1
# (recorded output: [4, 100, 192]).
embeddings.repeat(1, 100, 1).shape

torch.Size([4, 100, 192])

In [44]:
# Scratch cell: builds and displays a two-element tuple. `temp` is not used
# by any other cell visible in this notebook.
temp = ('a','b')
temp

('a', 'b')

## Pretrained acoustic model

In [45]:
# Load the "hubert_soft" entry point of the bshall/acoustic-model repository
# through torch.hub (the recorded run reused a cached copy).
acoustic = torch.hub.load("bshall/acoustic-model:main", "hubert_soft")

Using cache found in /root/.cache/torch/hub/bshall_acoustic-model_main


In [118]:
# Replace the pretrained decoder's prenet and output projection with no-op
# Identity layers, then display the resulting module. This mutates `acoustic`
# in place.
acoustic.decoder.prenet = nn.Identity()
acoustic.decoder.proj = nn.Identity()
acoustic.decoder

Decoder(
  (prenet): Identity()
  (lstm1): LSTM(768, 768, batch_first=True)
  (lstm2): LSTM(768, 768, batch_first=True)
  (lstm3): LSTM(768, 768, batch_first=True)
  (proj): Identity()
)

In [120]:
# State dicts of the pretrained acoustic model.
# The encoder dict used to be assigned to `saved_state_dict` and immediately
# overwritten by the decoder dict; it now has its own name so nothing is
# silently lost. `saved_state_dict` still holds the decoder weights, which is
# what the transfer cell below reads.
saved_encoder_state_dict = acoustic.encoder.state_dict()
saved_state_dict = acoustic.decoder.state_dict()

In [124]:
# State dicts of the model that will receive the pretrained weights.
# The content-encoder dict used to be assigned to `state_dict` and immediately
# overwritten; it now has its own name. `state_dict` still holds the decoder
# parameters, which is what the transfer cell below iterates over.
# NOTE(review): `bshall` is created in a cell placed further down the
# notebook; on a fresh top-to-bottom run that cell has to come first.
encoder_state_dict = bshall.frame_synth.content_encoder.state_dict()
state_dict = bshall.frame_synth.decoder.state_dict()

In [126]:
# Build the state dict to load into the target decoder: take each parameter
# from the pretrained checkpoint when it exists there, otherwise keep the
# target's current value and report the missing key.
# An explicit membership test replaces the former bare `except:`, which would
# also have hidden unrelated errors (including KeyboardInterrupt).
new_state_dict = {}
for k, v in state_dict.items():
    if k in saved_state_dict:
        new_state_dict[k] = saved_state_dict[k]
    else:
        print("Param {} is not in the checkpoint".format(k))
        new_state_dict[k] = v
# bshall.frame_synth.content_encoder.load_state_dict(new_state_dict)
bshall.frame_synth.decoder.load_state_dict(new_state_dict)

Param proj.weight is not in the checkpoint


<All keys matched successfully>

In [119]:
# Instantiate the model from the loaded config.
# NOTE(review): this cell sits after the cells that read and load
# `bshall.frame_synth` state, although its execution count (119) is lower than
# theirs (124, 126). Run top to bottom, those cells would hit an undefined
# `bshall`, and this cell would presumably replace a model whose weights had
# just been loaded — confirm and move it above them.
bshall = BShall_Ecapa(args)

In [90]:
# Show only the first parameter of the content encoder's prenet (name and raw values).
prenet_params = bshall.frame_synth.content_encoder.prenet.named_parameters()
for name, param in itertools.islice(prenet_params, 1):
    print(name, param.data)

net.0.weight tensor([[-0.0749, -0.0172,  0.0069,  ...,  0.0507, -0.0408, -0.0072],
        [ 0.0968, -0.0022, -0.1447,  ..., -0.0206, -0.0410, -0.0390],
        [ 0.0462,  0.0334, -0.0959,  ...,  0.1804, -0.0050,  0.0129],
        ...,
        [ 0.0493,  0.0238, -0.0126,  ..., -0.0795, -0.0084, -0.0246],
        [-0.0026,  0.0113,  0.0067,  ..., -0.0248,  0.0593, -0.0144],
        [-0.0027,  0.0532,  0.0070,  ...,  0.0134,  0.0716,  0.0114]])


## HyFace

In [7]:
# Build the HyFace model and the discriminator from the same config.
# NOTE(review): `HyFace` is not among the names imported from `hyface` in the
# import cell (Nansy, BShall_Ecapa, BShall_Nimbre) — confirm where it comes
# from before relying on a fresh-kernel run.
hyface = HyFace(args)
disc = Discriminator(args)



In [46]:
# First turn the real clip into a batch of one and print its shape
# (recorded: [1, 64000]) ...
tudio = torch.tensor(audio).unsqueeze(0)
print(tudio.shape)
# ... then overwrite both `tudio` and `hubert` with random stand-ins of batch
# size 4. From here on these names no longer refer to dataset items.
# No random seed is set, so the values differ between runs.
tudio = torch.randn([4, 87879])
hubert = torch.randn([4, 256, 274])
tudio.shape, hubert.shape

torch.Size([1, 64000])


(torch.Size([4, 87879]), torch.Size([4, 256, 274]))

In [7]:
# Run `hyface.logmel` on the random audio batch
# (recorded output shape: [4, 80, 549]).
mel = hyface.logmel(tudio)
mel.shape

torch.Size([4, 80, 549])

In [8]:
# Run the timbre analysis on the audio batch; it returns two tensors
# (recorded shapes: [4, 192] and [4, 128, 50]).
timbre_global, timbre_bank = hyface.analyze_timbre(tudio)
timbre_global.shape, timbre_bank.shape

(torch.Size([4, 192]), torch.Size([4, 128, 50]))

In [9]:
# Broadcast the global timbre vector over every frame, stack it under the
# content features along the channel axis, and sample the timbre from that.
n_frames = hubert.shape[-1]
timbre_per_frame = timbre_global.unsqueeze(-1).expand(-1, -1, n_frames)
contents = torch.cat((hubert, timbre_per_frame), dim=1)
timbre_sampled = hyface.timbre.sample_timber(contents, timbre_global, timbre_bank)
timbre_sampled.shape

torch.Size([4, 192, 274])

In [14]:
# Run synthesis from the content features and both timbre tensors
# (recorded output shape: [4, 80, 548]).
synth = hyface.synthesize(hubert, timbre_global, timbre_bank)
synth.shape

torch.Size([4, 80, 548])

## BShall

In [20]:
# Switch to the plain BShall configuration. `args` is rebound here, so every
# cell below sees this config, not the one loaded at the top.
# The JSON is parsed straight from the file handle. The previous version read
# the text into `data`, which replaced the batch dict of the same name that
# later cells still index with data['frame_lengths'].
config_path = '../training/configs/bshall.json'
with open(config_path, "r") as f:
    config = json.load(f)
args = utils.HParams(**config)

In [64]:
# Random stand-ins for a batch of audio, content features and per-frame timbre.
# No random seed is set, so the values differ between runs.
tudio = torch.randn([4, 87879])
hubert = torch.randn([4, 256, 274])
timbre = torch.randn([4, 192, 274])
# Display the shape of the tensor created here (`timbre`), not of the
# unrelated `timbre_sampled` left over from the HyFace section.
tudio.shape, hubert.shape, timbre.shape

(torch.Size([4, 87879]), torch.Size([4, 256, 274]), torch.Size([4, 192, 274]))

In [53]:
# NOTE(review): the recorded run of this cell failed with a TypeError: the
# constructor also requires `timbre_dim`, `decoder_hidden` and `out_dim`.
# The intended values are not visible in this notebook; supply them (or
# remove the cell) so that a full run does not stop here.
acmodel = AcousticModel(args)

TypeError: AcousticModel.__init__() missing 3 required positional arguments: 'timbre_dim', 'decoder_hidden', and 'out_dim'

In [88]:
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

class Encoder(nn.Module):
    """PreNet followed by a stack of 1-D convolutions.

    Input and output are laid out as (batch, time, features). The PreNet maps
    ``prenet_indim`` features to 256, and the convolution stack produces 512
    features per frame.

    Args:
        prenet_indim: Size of the last axis of the input.
        upsample: When True, a stride-2 transposed convolution inside the
            stack doubles the number of frames; when False that slot is an
            ``nn.Identity`` so the layer indices stay the same.
    """

    def __init__(self, prenet_indim: int, upsample: bool = True):
        super().__init__()
        self.prenet = PreNet(prenet_indim, 256, 256)

        # First conv block: 256 -> 512 channels, length preserved.
        stack = [
            nn.Conv1d(256, 512, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.InstanceNorm1d(512),
        ]
        # Optional 2x temporal upsampling.
        if upsample:
            stack.append(nn.ConvTranspose1d(512, 512, kernel_size=4, stride=2, padding=1))
        else:
            stack.append(nn.Identity())
        # Two more length-preserving conv blocks at 512 channels.
        for _ in range(2):
            stack.extend(
                [
                    nn.Conv1d(512, 512, kernel_size=5, stride=1, padding=2),
                    nn.ReLU(),
                    nn.InstanceNorm1d(512),
                ]
            )
        self.convs = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape (batch, time, prenet_indim)."""
        hidden = self.prenet(x)
        # Conv1d works on (batch, channels, time); swap in and back out.
        channels_first = hidden.transpose(1, 2)
        return self.convs(channels_first).transpose(1, 2)


class PreNet(nn.Module):
    """Two ``Linear -> ReLU -> Dropout`` stages applied over the last axis.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the intermediate layer.
        output_size: Number of output features.
        dropout: Dropout probability used after each ReLU.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        dropout: float = 0.5,
    ):
        super().__init__()
        stages = []
        for n_in, n_out in ((input_size, hidden_size), (hidden_size, output_size)):
            stages += [nn.Linear(n_in, n_out), nn.ReLU(), nn.Dropout(dropout)]
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stack to ``x``; only the last axis changes size."""
        out = self.net(x)
        return out

class Decoder(nn.Module):
    """Three stacked LSTMs with residual connections and a linear projection.

    The layer widths used to be hard-coded (1024 in, 1024 hidden, 80 out).
    They are now constructor parameters with those same defaults, so existing
    ``Decoder(args)`` calls and checkpoints behave as before.

    Args:
        args: Configuration object; stored on the instance and not otherwise
            used by this class.
        input_dim: Feature size fed to the first LSTM, i.e. the size of the
            concatenated input.
        hidden_dim: Hidden size of all three LSTMs.
        out_dim: Size of the projected output frame.
    """

    def __init__(self, args, input_dim: int = 1024, hidden_dim: int = 1024, out_dim: int = 80):
        super().__init__()
        self.lstm1 = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.lstm2 = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.lstm3 = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.proj = nn.Linear(hidden_dim, out_dim, bias=False)
        self.args = args

    def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
        """Decode a whole sequence in one pass.

        Args:
            x: Tensor of shape (batch, time, features).
            mels: Tensor of shape (batch, time, features) that is concatenated
                to ``x`` on the last axis; the two feature sizes must add up
                to ``input_dim``.

        Returns:
            Tensor of shape (batch, time, out_dim).
        """
        x = torch.cat((x, mels), dim=-1)
        x, _ = self.lstm1(x)
        res = x
        x, _ = self.lstm2(x)
        x = res + x  # residual around the second LSTM
        res = x
        x, _ = self.lstm3(x)
        x = res + x  # residual around the third LSTM
        return self.proj(x)

    @torch.inference_mode()
    def generate(self, xs: torch.Tensor) -> torch.Tensor:
        """Decode frame by frame, feeding each output back as the next input.

        Args:
            xs: Tensor of shape (batch, time, features). Each frame is
                concatenated with the previous output frame (zeros for the
                first step), so ``features + out_dim`` must equal
                ``input_dim``.

        Returns:
            Tensor of shape (batch, time, out_dim).
        """
        batch = xs.size(0)
        hidden = self.lstm1.hidden_size
        out_dim = self.proj.out_features

        # A zero-length sequence would make torch.stack fail on an empty list.
        if xs.size(1) == 0:
            return xs.new_zeros(batch, 0, out_dim)

        # new_zeros matches both the device and the dtype of the input, so
        # the method also works when the model is not in the default dtype.
        m = xs.new_zeros(batch, out_dim)
        h1, c1 = xs.new_zeros(1, batch, hidden), xs.new_zeros(1, batch, hidden)
        h2, c2 = xs.new_zeros(1, batch, hidden), xs.new_zeros(1, batch, hidden)
        h3, c3 = xs.new_zeros(1, batch, hidden), xs.new_zeros(1, batch, hidden)

        mel = []
        for x in torch.unbind(xs, dim=1):
            # (batch, features + out_dim) -> (batch, 1, input_dim)
            x = torch.cat((x, m), dim=1).unsqueeze(1)
            x1, (h1, c1) = self.lstm1(x, (h1, c1))
            x2, (h2, c2) = self.lstm2(x1, (h2, c2))
            x = x1 + x2
            x3, (h3, c3) = self.lstm3(x, (h3, c3))
            x = x + x3
            m = self.proj(x).squeeze(1)
            mel.append(m)
        return torch.stack(mel, dim=1)

In [89]:
# Two encoders with upsampling enabled: one for 256-feature inputs and one
# for 192-feature inputs.
enc = Encoder(256, True)
enc2 = Encoder(192, True)

In [90]:
# The encoders take (batch, time, features), hence the transposes.
# Recorded output: both results are [4, 548, 512], i.e. twice the 274 input frames.
x = enc(hubert.transpose(1,2))
y = enc2(timbre.transpose(1,2))
x.shape, y.shape

(torch.Size([4, 548, 512]), torch.Size([4, 548, 512]))

In [91]:
# Concatenate the two encodings on the feature axis
# (recorded output: [4, 548, 1024]).
catt = torch.cat((x, y), dim=-1)
catt.shape

torch.Size([4, 548, 1024])

In [92]:
# Run the decoder with `x` as its first input and `y` in the `mels` slot;
# the two are concatenated inside forward (512 + 512 = 1024 features).
# NOTE(review): the recorded output's last dimension (128) does not match the
# 80-feature projection of the Decoder class defined in this notebook — the
# stored output looks stale; re-run to confirm.
decoder = Decoder(args)
z = decoder(x, y)
z.shape

torch.Size([4, 548, 128])

In [93]:
# Two random [4, 80, 100] tensors used by the loss experiments below.
# No random seed is set, so the recorded numbers will not reproduce exactly.
a = torch.randn([4, 80, 100])
b = torch.randn([4, 80, 100])

In [96]:
# Element-wise difference; print its shape and show the first batch item.
c = a-b
print(c.shape)
c[0]

torch.Size([4, 80, 100])


tensor([[-0.7204, -0.7130, -3.0918,  ...,  2.1229, -1.9494, -0.7566],
        [-0.0798,  1.2675, -0.9218,  ...,  0.5414, -0.3596,  0.1896],
        [-0.0225,  0.3200,  0.6995,  ...,  0.0965,  0.7835,  0.3701],
        ...,
        [ 1.6791, -1.1309, -0.1175,  ..., -1.7433, -1.1525, -0.1156],
        [-1.7790,  0.8525, -0.5253,  ...,  0.7850, -0.0322,  0.5535],
        [-1.2482, -2.1930,  2.0543,  ..., -0.7643,  1.7816,  1.3197]])

In [97]:
# Mean absolute difference over all elements (recorded value: 1.1425).
c.abs().mean()

tensor(1.1425)

In [100]:
# Unreduced L1 loss: same shape as the inputs. The recorded values are the
# absolute values of the `a - b` entries shown earlier.
loss = F.l1_loss(a, b, reduction="none")
print(loss.shape)
loss[0]

torch.Size([4, 80, 100])


tensor([[0.7204, 0.7130, 3.0918,  ..., 2.1229, 1.9494, 0.7566],
        [0.0798, 1.2675, 0.9218,  ..., 0.5414, 0.3596, 0.1896],
        [0.0225, 0.3200, 0.6995,  ..., 0.0965, 0.7835, 0.3701],
        ...,
        [1.6791, 1.1309, 0.1175,  ..., 1.7433, 1.1525, 0.1156],
        [1.7790, 0.8525, 0.5253,  ..., 0.7850, 0.0322, 0.5535],
        [1.2482, 2.1930, 2.0543,  ..., 0.7643, 1.7816, 1.3197]])

In [112]:
# Per-item loss sum divided by that item's frame count.
# The recorded result is float64, since the divisor is a NumPy array.
# NOTE(review): this needs `data` to be the batch dict from the loader, not a
# later rebinding of that name — verify on a fresh top-to-bottom run.
# NOTE(review): `loss` comes from random 100-frame tensors while the recorded
# frame lengths go up to 548, so the numbers here are illustrative only.
torch.sum(loss, dim=(1,2)) / data['frame_lengths']

tensor([16.8403, 24.3411, 20.6968, 18.5207], dtype=torch.float64)

In [114]:
# Per-item loss summed over dims 1 and 2 (one value per batch item).
torch.sum(loss, dim=(1,2))

tensor([9228.4756, 9006.2051, 9065.1963, 9260.3457])

In [115]:
# Frame counts of the batch items (the recorded output is a NumPy array of four ints).
data['frame_lengths']

array([548, 370, 438, 500])