In [1]:
import os

# Expose only GPU 0 to this process. This is set here, before `torch` is
# imported in the next cell, so the restriction is in place when CUDA starts up.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
import os
import shutil
import os.path as osp

import torch
from torch import nn
import torch.nn.functional as F

from accelerate import Accelerator
from accelerate.utils import LoggerType

from transformers import AdamW
from transformers import AlbertConfig, AlbertModel
from accelerate import DistributedDataParallelKwargs

from model import MultiTaskModel
from dataloader import build_dataloader
from utils import length_to_mask, scan_checkpoint

from datasets import load_from_disk

In [3]:
import yaml
import pickle

# Path of the YAML run configuration; train() also copies this file into the log directory.
config_path = "Configs/config.yml"

# Read the config inside a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

In [4]:
import pickle

# Location of the pickled token map, taken from the dataset section of the config.
# NOTE(security): pickle.load can execute arbitrary code while deserialising,
# so this path must only ever point at a trusted file.
token_maps_path = config['dataset_params']['token_maps']
with open(token_maps_path, 'rb') as handle:
    token_maps = pickle.load(handle)

In [5]:
from transformers import AutoTokenizer

# The tokenizer identifier is read from the dataset section of the config.
tokenizer_name = config['dataset_params']['tokenizer']
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

In [6]:
# Cross-entropy (classification) loss; train() applies it to both the word
# predictions and the masked-token predictions.
criterion = nn.CrossEntropyLoss()

# Bookkeeping globals.
best_loss = float('inf')  # lowest test loss seen so far
start_epoch = 0           # epoch to start from (0, or the last checkpoint's epoch)
loss_train_record = []
loss_test_record = []

# Step budget and logging / checkpointing cadence, all read from the config.
num_steps, log_interval, save_interval = (
    config[key] for key in ('num_steps', 'log_interval', 'save_interval')
)

In [7]:
# Display the dataset parameters as a quick sanity check of the loaded config.
config['dataset_params']

{'tokenizer': 'mesolitica/PL-BERT-MS',
 'token_separator': ' ',
 'token_mask': 'M',
 'word_separator': 2,
 'token_maps': 'token_maps.pkl',
 'max_mel_length': 256,
 'word_mask_prob': 0.15,
 'phoneme_mask_prob': 0.1,
 'replace_prob': 0.2}

In [8]:
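# NOTE: disabled scratch cell. It repeats the dataloader/model setup that train() performs; kept for reference only.
# batch_size = config["batch_size"]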
# dataset = load_from_disk(config["data_folder"])
# train_loader = build_dataloader(dataset, 
#                                 batch_size=batch_size, 
#                                 num_workers=0, 
#                                 dataset_config=config['dataset_params'])
# albert_base_configuration = AlbertConfig(**config['model_params'])
# bert = AlbertModel(albert_base_configuration)
# bert = MultiTaskModel(bert, 
#                       num_vocab=1 + max([m['token'] for m in token_maps.values()]), 
#                       num_tokens=config['model_params']['vocab_size'],
#                       hidden_size=config['model_params']['hidden_size'])

In [9]:
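# NOTE: disabled scratch cell for inspecting a single batch; the `phonemes =` line below was left unfinished.
# for _, batch in enumerate(train_loader):        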

#     words, labels, phonemes, input_lengths, masked_indices = batch
#     phonemes = 
#     text_mask = length_to_mask(torch.Tensor(input_lengths))
    
#     break

#     # tokens_pred, words_pred = bert(phonemes, attention_mask=(~text_mask).int())

In [10]:
def _latest_checkpoint_step(log_dir):
    """Return the highest step N among ``step_<N>.t7`` files in ``log_dir``.

    Returns ``None`` when the directory holds no such checkpoint. Entries that
    are not regular files, or whose step part is not an integer, are skipped.
    """
    prefix, suffix = "step_", ".t7"
    latest = None
    for fname in os.listdir(log_dir):
        if not (fname.startswith(prefix) and fname.endswith(suffix)):
            continue
        if not osp.isfile(osp.join(log_dir, fname)):
            continue
        try:
            step = int(fname[len(prefix):-len(suffix)])
        except ValueError:
            continue
        if latest is None or step > latest:
            latest = step
    return latest


def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys.

    Keys that do not start with the prefix are kept unchanged, so checkpoints
    saved with or without a wrapping ``module.`` level both load correctly.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def train():
    """Run one pass over the training dataloader, resuming from the newest checkpoint.

    Uses the notebook-level globals ``config``, ``config_path``, ``token_maps``,
    ``criterion``, ``num_steps``, ``log_interval`` and ``save_interval``.

    The function:
      * copies the config file into ``config['log_dir']``;
      * builds the dataloader and an ALBERT model wrapped in ``MultiTaskModel``;
      * restores model and optimizer state from the highest ``step_<N>.t7``
        checkpoint in the log directory, if one exists;
      * optimises the sum of the word loss and the masked-token loss, logging
        every ``log_interval`` steps and checkpointing every ``save_interval`` steps.

    Returns ``None``, either when the dataloader is exhausted or once more than
    ``num_steps`` batches have been processed in this call.
    """
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

    dataset = load_from_disk(config["data_folder"])

    log_dir = config['log_dir']
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))

    batch_size = config["batch_size"]
    train_loader = build_dataloader(dataset,
                                    batch_size=batch_size,
                                    num_workers=0,
                                    dataset_config=config['dataset_params'])

    albert_base_configuration = AlbertConfig(**config['model_params'])

    bert = AlbertModel(albert_base_configuration)
    bert = MultiTaskModel(bert,
                          num_vocab=1 + max([m['token'] for m in token_maps.values()]),
                          num_tokens=config['model_params']['vocab_size'],
                          hidden_size=config['model_params']['hidden_size'])

    # Look for a checkpoint explicitly rather than relying on a bare `except`
    # to signal "nothing to resume from".
    resume_step = _latest_checkpoint_step(log_dir)
    load = resume_step is not None
    iters = resume_step if load else 0

    optimizer = AdamW(bert.parameters(), lr=1e-4)

    accelerator = Accelerator(mixed_precision=config['mixed_precision'], split_batches=True, kwargs_handlers=[ddp_kwargs])
    # Use the device accelerate selected instead of hard-coding 'cuda'.
    device = accelerator.device

    if load:
        checkpoint = torch.load(osp.join(log_dir, "step_" + str(iters) + ".t7"), map_location='cpu')

        # Only strip `module.` where it is actually present. Slicing every key
        # unconditionally mangles keys saved without the prefix, and with
        # strict=False that silently loads no weights at all.
        load_result = bert.load_state_dict(_strip_module_prefix(checkpoint['net']), strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            accelerator.print('Checkpoint key mismatch: %d missing, %d unexpected.'
                              % (len(load_result.missing_keys), len(load_result.unexpected_keys)))

        accelerator.print('Checkpoint loaded.')
        optimizer.load_state_dict(checkpoint['optimizer'])

    bert, optimizer, train_loader = accelerator.prepare(
        bert, optimizer, train_loader
    )

    accelerator.print('Start training...')

    running_loss = 0

    # `curr_steps` counts batches processed in this call; `iters` is the global
    # step, which continues from the resumed checkpoint.
    for curr_steps, batch in enumerate(train_loader, start=1):
        words, labels, phonemes, input_lengths, masked_indices = batch
        phonemes = phonemes.to(device)
        text_mask = length_to_mask(torch.Tensor(input_lengths)).to(device)

        tokens_pred, words_pred = bert(phonemes, attention_mask=(~text_mask).int())

        # Word loss: per-sample cross-entropy over the unpadded positions,
        # averaged over the batch.
        loss_vocab = 0
        for _s2s_pred, _text_input, _text_length in zip(words_pred, words, input_lengths):
            loss_vocab += criterion(_s2s_pred[:_text_length],
                                    _text_input[:_text_length])
        loss_vocab /= words.size(0)

        # Token loss: cross-entropy restricted to the masked positions of each
        # sample that has any.
        # NOTE(review): the divisor starts at 1, so this is
        # sum / (number of samples with masks + 1). Left unchanged so losses
        # stay comparable with existing runs; confirm whether that is intended.
        loss_token = 0
        sizes = 1
        for _s2s_pred, _text_input, _text_length, _masked_indices in zip(tokens_pred, labels, input_lengths, masked_indices):
            if len(_masked_indices) > 0:
                _target = _text_input[:_text_length][_masked_indices]
                loss_token += criterion(_s2s_pred[:_text_length][_masked_indices], _target)
                sizes += 1
        loss_token /= sizes

        loss = loss_vocab + loss_token

        optimizer.zero_grad()
        accelerator.backward(loss)
        optimizer.step()

        running_loss += loss.item()

        # `iters` now equals the number of completed steps, so it is used
        # directly for the interval checks, the log line and the checkpoint name.
        iters += 1
        if iters % log_interval == 0:
            accelerator.print('Step [%d/%d], Loss: %.5f, Vocab Loss: %.5f, Token Loss: %.5f'
                              % (iters, num_steps, running_loss / log_interval, loss_vocab, loss_token))
            running_loss = 0

        if iters % save_interval == 0:
            accelerator.print('Saving..')

            state = {
                'net': bert.state_dict(),
                'step': iters,
                'optimizer': optimizer.state_dict(),
            }

            accelerator.save(state, osp.join(log_dir, 'step_' + str(iters) + '.t7'))

        # NOTE(review): this limit applies to steps taken in this call, not to
        # the global `iters` shown in the log line.
        if curr_steps > num_steps:
            return

In [None]:
from accelerate import notebook_launcher
# Relaunch train() each time it returns; every launch resumes from the newest
# checkpoint found in the log directory.
# NOTE(review): this loop has no exit condition, so the cell only stops when
# the kernel is interrupted. train() returns None in every case, so there is
# no completion signal to test here. Confirm the intended stop criterion.
while True:
    notebook_launcher(train, args=(), num_processes=1, use_port=33389)

Launching training on one GPU.
177




Start training...
Step [10/1000000], Loss: 14.01790, Vocab Loss: 11.38148, Token Loss: 3.27761
Step [20/1000000], Loss: 14.13448, Vocab Loss: 11.06839, Token Loss: 2.96207
Step [30/1000000], Loss: 13.67338, Vocab Loss: 10.68721, Token Loss: 2.93207
Step [40/1000000], Loss: 13.28371, Vocab Loss: 9.95596, Token Loss: 2.90517
Step [50/1000000], Loss: 12.72050, Vocab Loss: 9.87544, Token Loss: 2.91731
Step [60/1000000], Loss: 12.33892, Vocab Loss: 9.28609, Token Loss: 2.82335
Step [70/1000000], Loss: 11.89261, Vocab Loss: 8.75583, Token Loss: 2.85019
Step [80/1000000], Loss: 11.65722, Vocab Loss: 8.88741, Token Loss: 2.87086
Step [90/1000000], Loss: 11.30377, Vocab Loss: 8.48093, Token Loss: 2.78896
Step [100/1000000], Loss: 11.11617, Vocab Loss: 8.32483, Token Loss: 2.79817
Step [110/1000000], Loss: 10.81688, Vocab Loss: 7.78804, Token Loss: 2.97620
Step [120/1000000], Loss: 10.88988, Vocab Loss: 7.78295, Token Loss: 2.83096
Step [130/1000000], Loss: 10.51173, Vocab Loss: 7.67728, Token L

Step [1080/1000000], Loss: 9.71558, Vocab Loss: 7.00485, Token Loss: 2.81119
Step [1090/1000000], Loss: 9.63988, Vocab Loss: 6.88409, Token Loss: 2.77340
Step [1100/1000000], Loss: 9.80692, Vocab Loss: 7.49187, Token Loss: 2.80837
Step [1110/1000000], Loss: 9.90156, Vocab Loss: 6.97016, Token Loss: 2.69760
Step [1120/1000000], Loss: 9.77295, Vocab Loss: 7.13341, Token Loss: 2.66087
Step [1130/1000000], Loss: 9.83911, Vocab Loss: 6.53768, Token Loss: 2.85607
Step [1140/1000000], Loss: 9.94759, Vocab Loss: 6.90035, Token Loss: 2.78868
Step [1150/1000000], Loss: 9.85416, Vocab Loss: 6.81202, Token Loss: 2.73243
Step [1160/1000000], Loss: 9.83770, Vocab Loss: 7.03450, Token Loss: 2.87112
Step [1170/1000000], Loss: 9.86228, Vocab Loss: 7.43615, Token Loss: 2.81876
Step [1180/1000000], Loss: 9.77952, Vocab Loss: 7.08521, Token Loss: 2.86816
Step [1190/1000000], Loss: 9.91942, Vocab Loss: 7.20185, Token Loss: 2.76659
Step [1200/1000000], Loss: 9.88025, Vocab Loss: 7.28602, Token Loss: 2.73952

Step [2150/1000000], Loss: 9.89629, Vocab Loss: 7.70423, Token Loss: 2.82994
Step [2160/1000000], Loss: 9.65982, Vocab Loss: 6.35969, Token Loss: 2.75238
Step [2170/1000000], Loss: 9.79693, Vocab Loss: 6.97296, Token Loss: 2.69216
Step [2180/1000000], Loss: 9.73816, Vocab Loss: 7.10723, Token Loss: 2.78294
Step [2190/1000000], Loss: 9.87784, Vocab Loss: 7.47158, Token Loss: 2.58033
Step [2200/1000000], Loss: 9.78895, Vocab Loss: 6.97951, Token Loss: 2.89181
Step [2210/1000000], Loss: 9.71250, Vocab Loss: 7.48609, Token Loss: 2.76563
Step [2220/1000000], Loss: 9.89419, Vocab Loss: 7.25998, Token Loss: 2.65018
Step [2230/1000000], Loss: 9.63709, Vocab Loss: 6.99374, Token Loss: 2.82866
Step [2240/1000000], Loss: 9.77043, Vocab Loss: 6.81307, Token Loss: 2.77799
Step [2250/1000000], Loss: 9.87271, Vocab Loss: 6.57449, Token Loss: 2.84043
Step [2260/1000000], Loss: 9.53252, Vocab Loss: 6.51200, Token Loss: 2.85270
Step [2270/1000000], Loss: 9.80228, Vocab Loss: 6.63694, Token Loss: 2.86935