In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler  import OneCycleLR

from src.configs import FastSpeechConfig
from src.configs import TrainConfig

from src.util import BufferDataset
from src.util import download_buffer
from src.util import collate_fn_tensor

from src.wandb_writer import WanDBWriter
from src.model import FastSpeech
from src.loss import FastSpeechLoss
from src.train import train

In [None]:
# Build both configuration objects with their default values.
# `model_config` is later handed to FastSpeech(...); `train_config` supplies the
# batch sizes, learning rate, epoch count and device used further down.
model_config, train_config = FastSpeechConfig(), TrainConfig()

In [None]:
# Divisors applied to every sample's 'energy' and 'pitch' entries.
# NOTE(review): presumably dataset-wide maxima, so that both features land in
# roughly [0, 1] -- confirm against the preprocessing that produced the buffer.
ENERGY_SCALE = 488
PITCH_SCALE = 862

# NOTE(review): download_buffer() is assumed to place 'saved_buffer.pkl' in the
# working directory -- confirm in src.util.
download_buffer()

# SECURITY: torch.load unpickles the file, and unpickling can execute arbitrary
# code. Only run this cell against a buffer obtained from a trusted source.
buffer = torch.load('saved_buffer.pkl')
for buf in buffer:
    # In-place division: the entries stored in `buffer` are modified directly.
    buf['energy'] /= ENERGY_SCALE
    buf['pitch'] /= PITCH_SCALE

dataset = BufferDataset(buffer)

# Every item drawn from the loader contains batch_expand_size * batch_size
# samples; incomplete trailing batches are dropped.
# NOTE(review): collate_fn_tensor presumably splits each item back into
# batch_expand_size real batches -- confirm in src.util.
training_loader = DataLoader(
    dataset,
    batch_size=train_config.batch_expand_size * train_config.batch_size,
    shuffle=True,
    collate_fn=collate_fn_tensor,
    drop_last=True,
    num_workers=0
)

In [None]:
# Instantiate the network and move it to the configured device in one step.
model = FastSpeech(model_config).to(train_config.device)

fastspeech_loss = FastSpeechLoss()
current_step = 0

# AdamW over all model parameters, starting from the configured learning rate.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_config.learning_rate,
    betas=(0.9, 0.98),
    eps=1e-9,
)

# One-cycle schedule with cosine annealing; the first 10% of the steps ramp the
# learning rate up to `max_lr`.
# The per-epoch step count is the loader length times batch_expand_size.
# NOTE(review): this presumably reflects one scheduler step per expanded
# sub-batch inside train() -- confirm in src.train.
scheduler = OneCycleLR(
    optimizer,
    steps_per_epoch=len(training_loader) * train_config.batch_expand_size,
    epochs=train_config.epochs,
    anneal_strategy="cos",
    max_lr=train_config.learning_rate,
    pct_start=0.1,
)

logger = WanDBWriter(train_config)

In [None]:
# Launch training with every object prepared in the cells above.
train(
    # what is being optimised, and against which objective
    model=model,
    fastspeech_loss=fastspeech_loss,
    # how parameters and the learning rate are updated
    optimizer=optimizer,
    scheduler=scheduler,
    # data source and run settings
    training_loader=training_loader,
    train_config=train_config,
    # experiment tracking
    logger=logger,
)