In [1]:
import os
import os.path as osp
import glob
import re
import sys
import yaml
import shutil
import numpy as np
import torch
import click
from socket import gethostname
import warnings
warnings.simplefilter('ignore')

# load packages
import random
import yaml
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa

from models import *
from meldataset import build_dataloader
from utils import *
from losses import *
from optimizers import build_optimizer
import time

from accelerate import Accelerator
from accelerate.utils import LoggerType
from accelerate import DistributedDataParallelKwargs

from torch.utils.tensorboard import SummaryWriter

import logging
from accelerate.logging import get_logger
logger = get_logger(__name__, log_level="DEBUG")

class MEM_PROBE(object):
    def __init__(self):
        self.last_mem = torch.cuda.memory_allocated()

    def __call__(self):
        current_mem = torch.cuda.memory_allocated()
        mem_delta = current_mem - self.last_mem
        self.last_mem = current_mem
        print(f"MEM:{sys._getframe(1).f_code.co_filename}:{sys._getframe(1).f_lineno} {mem_delta/(1024*1024*1024):.2f} : {current_mem/(1024*1024*1024):.2f} GB")
        return None


In [2]:
with open("/rhome/eingerman/Projects/DeepLearning/TTS/StyleTTS2/Models/tensors.pt", "rb") as f:
    (y_rec, wav)=torch.load(f)

In [3]:
# y_rec = y_rec[:,:,:8000]
# wav = wav[:,:8000]

print(y_rec.shape, wav.shape)

torch.Size([4, 1, 46800]) torch.Size([4, 46800])


In [9]:
config = yaml.safe_load(open("Configs/config_libritts_espeak.yml"))

device = "cuda:0"
model_params = recursive_munch(config['model_params'])
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)
from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)
model = build_model(model_params, text_aligner, pitch_extractor, plbert)

scheduler_params = {
    "max_lr": float(config['optimizer_params'].get('lr', 1e-4)),
    "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
    "epochs": 200,
    "steps_per_epoch": 1000,
}
optimizer = build_optimizer({key: model[key].parameters() for key in model},
                                scheduler_params_dict= {key: scheduler_params.copy() for key in model},
                            lr=float(config['optimizer_params'].get('lr', 1e-4)))


gl = GeneratorLoss(model.mpd, model.msd).to(device) 
dl = DiscriminatorLoss(model.mpd, model.msd).to(device)

sr = config['preprocess_params'].get('sr', 24000)
wl = WavLMLoss(model_params.slm.model, 
                model.wd, 
                sr, 
                model_params.slm.sr).to(device)

Some weights of the model checkpoint at microsoft/wavlm-base-plus were not used when initializing WavLMModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing WavLMModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing WavLMModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of WavLMModel were not initialized from the model checkpoint at microsoft/wavlm-base-plus and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictio

In [5]:
# for k, v in optimizer.optimizers.items():
#     print(k, v)

bert AdamW (
Parameter Group 0
    amsgrad: False
    base_momentum: 0.85
    betas: (0.8500000000061685, 0.99)
    capturable: False
    differentiable: False
    eps: 1e-09
    foreach: None
    fused: None
    initial_lr: 0.0001
    lr: 0.0001
    max_lr: 0.0001
    max_momentum: 0.95
    maximize: False
    min_lr: 0.0001
    weight_decay: 0.0001
)
bert_encoder AdamW (
Parameter Group 0
    amsgrad: False
    base_momentum: 0.85
    betas: (0.8500000000061685, 0.99)
    capturable: False
    differentiable: False
    eps: 1e-09
    foreach: None
    fused: None
    initial_lr: 0.0001
    lr: 0.0001
    max_lr: 0.0001
    max_momentum: 0.95
    maximize: False
    min_lr: 0.0001
    weight_decay: 0.0001
)
predictor AdamW (
Parameter Group 0
    amsgrad: False
    base_momentum: 0.85
    betas: (0.8500000000061685, 0.99)
    capturable: False
    differentiable: False
    eps: 1e-09
    foreach: None
    fused: None
    initial_lr: 0.0001
    lr: 0.0001
    max_lr: 0.0001
    max_mom

In [8]:
optimizer.zero_grad()
w = wav.detach().unsqueeze(1).float()
y = y_rec.detach()
l = w.shape[2]
idx = tuple(range(2048, l, 2048))

memory_probe = MEM_PROBE()
for w_chunk, y_chunk in zip(torch.tensor_split(w, idx, dim=2), torch.tensor_split(y, idx, dim=2)):
    d_loss = dl(w_chunk, y_chunk)
    d_loss.backward()
memory_probe()


# memory_probe()

MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.31 : 0.63 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.15 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 -0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78 GB
MEM:/tmp/ipykernel_3837622/2444908502.py:10 0.00 : 0.78

In [33]:
torch.cuda.empty_cache()

In [24]:
# torch.chunk(w, chunks=8, dim=2)
l = w.shape[2]
idx = tuple(range(2048, l, 2048))
[t.shape for t in torch.tensor_split(w, idx, dim=2)]

[torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 2048]),
 torch.Size([4, 1, 1744])]

In [22]:
idx


(0,
 2048,
 4096,
 6144,
 8192,
 10240,
 12288,
 14336,
 16384,
 18432,
 20480,
 22528,
 24576,
 26624,
 28672,
 30720,
 32768,
 34816,
 36864,
 38912,
 40960,
 43008,
 45056)

In [15]:
total=0
for name, param in model.mpd.named_parameters():
    if param.requires_grad:
        total += param.numel()
        print(f"{name} : {param.numel()}")
print(f"Total: {total}")

discriminators.0.convs.0.bias : 32
discriminators.0.convs.0.weight_g : 32
discriminators.0.convs.0.weight_v : 160
discriminators.0.convs.1.bias : 128
discriminators.0.convs.1.weight_g : 128
discriminators.0.convs.1.weight_v : 20480
discriminators.0.convs.2.bias : 512
discriminators.0.convs.2.weight_g : 512
discriminators.0.convs.2.weight_v : 327680
discriminators.0.convs.3.bias : 1024
discriminators.0.convs.3.weight_g : 1024
discriminators.0.convs.3.weight_v : 2621440
discriminators.0.convs.4.bias : 1024
discriminators.0.convs.4.weight_g : 1024
discriminators.0.convs.4.weight_v : 5242880
discriminators.0.conv_post.bias : 1
discriminators.0.conv_post.weight_g : 1
discriminators.0.conv_post.weight_v : 3072
discriminators.1.convs.0.bias : 32
discriminators.1.convs.0.weight_g : 32
discriminators.1.convs.0.weight_v : 160
discriminators.1.convs.1.bias : 128
discriminators.1.convs.1.weight_g : 128
discriminators.1.convs.1.weight_v : 20480
discriminators.1.convs.2.bias : 512
discriminators.1.c

In [14]:
type(param.grad)

NoneType

In [11]:
for w_chunk, y_chunk in zip(torch.tensor_split(wav.detach(), idx, dim=1), torch.tensor_split(y_rec, idx, dim=2)):
    loss_slm = wl(w_chunk.detach(), y_chunk)
    # accelerator.backward(loss_slm, retain_graph=True)  #TODO: check if this is correct
memory_probe()

MEM:/tmp/ipykernel_3837622/1564091087.py:4 0.41 : 0.87 GB


In [20]:
t=tuple(range(2048, 4100, 2048))
(t[:-1] + (4100,))

(2048, 4100)

In [18]:
tuple(4100)

TypeError: 'int' object is not iterable