In [26]:
import os
import pickle
import random
import subprocess
import torch.cuda
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from utils.earlystopping.protocols import EarlyStopping
from test_dataloder import *
import datetime
from utils.get_time import get_time
import gc
from tqdm import tqdm
from utils.warmup import *
import torch.nn.functional as F
from bartmodel import Bart
from transformers import get_linear_schedule_with_warmup

In [27]:
# Batch fields fed to the encoder (``src_<key>``) and decoder (``tgt_<key>``).
src_keys = ['strength', 'length', 'phrase']
tgt_keys = ['bar', 'pos', 'token', 'dur', 'phrase']

binary_dir = '/home/qihao/CS6207/binary'
words_dir = f'{binary_dir}/words'
hparams = {
    'batch_size': 8,
    'word_data_dir': words_dir,
    'sentence_maxlen': 512,
    'hidden_size': 768,
    'n_layers': 6,
    'n_head': 8,
    'pretrain': '',  # checkpoint path to warm-start from; '' means train from scratch
    # 'lr' used to be listed twice (1.0e-5, then 5.0e-5); in a dict literal the
    # later entry wins, so 5.0e-5 was the value actually in effect.
    'lr': 5.0e-5,
    'optimizer_adam_beta1': 0.9,
    'optimizer_adam_beta2': 0.98,
    'weight_decay': 0.01,
    'patience': 5,    # early-stopping patience (epochs)
    'warmup': 3000,   # LR warmup steps (optimizer steps, not epochs)
    'checkpoint_dir': '/home/qihao/CS6207/checkpoints',
    'drop_prob': 0.2,
    'total_epoch': 1000,
}

In [11]:
def set_seed(seed=1234):
    """Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch (CPU,
    the current CUDA device and all CUDA devices), then switches cuDNN to
    deterministic, non-autotuned kernels.

    Args:
        seed: integer seed applied to every generator (default 1234).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Deterministic cuDNN mode may slow training down, depending on the model.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [12]:
def xe_loss(outputs, targets):
    """Cross-entropy over a sequence of predictions, ignoring class id 0.

    Args:
        outputs: logits whose class dimension is last, e.g. (batch, seq, classes);
            dims 1 and 2 are swapped into the (batch, classes, seq) layout that
            ``F.cross_entropy`` expects.
        targets: integer class ids aligned with ``outputs`` minus its class dim;
            positions equal to 0 (the padding value seen in the batches) are ignored.

    Returns:
        Scalar tensor: mean loss over the non-ignored positions.
    """
    class_first_logits = outputs.transpose(1, 2)
    return F.cross_entropy(
        input=class_first_logits,
        target=targets,
        reduction='mean',
        ignore_index=0,
    )

In [21]:
def train(train_loader, model, optimizer, scheduler, epoch, total_epoch):
    """Run one training epoch and return the mean loss of every output head.

    Each batch is a mapping that holds ``src_<key>`` tensors (for ``src_keys``)
    and ``tgt_<key>`` tensors (for ``tgt_keys``). The model returns five logits
    tensors (bar, pos, token, dur, phrase). Each head is scored with ``xe_loss``
    under a one-step shift: the output at position ``t`` (``out[:, :-1]``) is
    compared with the target at position ``t + 1`` (``tgt[:, 1:]``). The five
    head losses are summed and back-propagated.

    Args:
        train_loader: sized iterable of batches.
        model: model to train; switched to train mode here.
        optimizer: stepped once per batch.
        scheduler: learning-rate scheduler, stepped once per batch after the optimizer.
        epoch: zero-based epoch index, used only for the progress-bar label.
        total_epoch: number of epochs, used only for the progress-bar label.

    Returns:
        ``(total, bar, pos, token, dur, phrase)``: per-batch means of each loss
        (NaN if the loader yields no batches).

    Uses the module-level globals ``device``, ``src_keys`` and ``tgt_keys``.
    """
    head_names = ('bar', 'pos', 'token', 'dur', 'phrase')
    loss_names = ('total',) + head_names
    history = {name: [] for name in loss_names}

    model.train()
    with tqdm(total=len(train_loader), ncols=150, position=0, leave=True) as _tqdm:  # one tick per batch
        _tqdm.set_description('training epoch: {}/{}'.format(epoch + 1, total_epoch))

        for data in train_loader:
            enc_inputs = {k: data[f'src_{k}'].to(device) for k in src_keys}
            dec_inputs = {k: data[f'tgt_{k}'].to(device) for k in tgt_keys}

            optimizer.zero_grad()
            # Explicit unpacking keeps the "model returns exactly five heads" check.
            bar_out, pos_out, token_out, dur_out, phrase_out = model(enc_inputs, dec_inputs)
            head_outputs = dict(zip(head_names, (bar_out, pos_out, token_out, dur_out, phrase_out)))

            # Output at step t is scored against the target at step t + 1.
            head_losses = {
                name: xe_loss(head_outputs[name][:, :-1], data[f'tgt_{name}'].to(device)[:, 1:])
                for name in head_names
            }
            total_loss = sum(head_losses.values())

            total_loss.backward()
            optimizer.step()
            scheduler.step()

            values = {'total': total_loss.item()}
            values.update({name: loss.item() for name, loss in head_losses.items()})
            for name in loss_names:
                history[name].append(values[name])

            _tqdm.set_postfix(
                loss="{:.3f}, bar={:.3f}, pos={:.3f}, token={:.3f}, dur={:.3f}, phrase={:.3f}".format(
                    *(values[name] for name in loss_names)))
            # Advance by exactly one batch so the bar ends at len(train_loader).
            _tqdm.update(1)

    return tuple(np.mean(history[name]) for name in loss_names)

In [22]:
def valid(valid_loader, model, epoch, total_epoch):
    """Evaluate the model on ``valid_loader`` and return the mean loss of every head.

    Uses the same one-step-shifted ``xe_loss`` scoring as ``train``, under
    ``torch.no_grad()`` and with the model in eval mode. A batch that raises
    is reported and skipped. Evaluation then continues, so the returned
    averages still cover every other batch instead of only the batches seen
    before the failure.

    Args:
        valid_loader: sized iterable of batches (``src_<key>`` / ``tgt_<key>`` tensors).
        model: model to evaluate; switched to eval mode here.
        epoch: zero-based epoch index, used only for the progress-bar label.
        total_epoch: number of epochs, used only for the progress-bar label.

    Returns:
        ``(total, bar, pos, token, dur, phrase)``: per-batch means of each loss
        over the batches that evaluated successfully (NaN if none did).

    Uses the module-level globals ``device``, ``src_keys`` and ``tgt_keys``.
    """
    head_names = ('bar', 'pos', 'token', 'dur', 'phrase')
    loss_names = ('total',) + head_names
    history = {name: [] for name in loss_names}
    n_skipped = 0

    model.eval()
    with tqdm(total=len(valid_loader), ncols=150) as _tqdm:  # one tick per batch
        _tqdm.set_description('validation epoch: {}/{}'.format(epoch + 1, total_epoch))

        with torch.no_grad():
            for data in valid_loader:
                try:
                    enc_inputs = {k: data[f'src_{k}'].to(device) for k in src_keys}
                    dec_inputs = {k: data[f'tgt_{k}'].to(device) for k in tgt_keys}

                    bar_out, pos_out, token_out, dur_out, phrase_out = model(enc_inputs, dec_inputs)
                    head_outputs = dict(zip(head_names, (bar_out, pos_out, token_out, dur_out, phrase_out)))

                    # Output at step t is scored against the target at step t + 1.
                    head_losses = {
                        name: xe_loss(head_outputs[name][:, :-1], data[f'tgt_{name}'].to(device)[:, 1:])
                        for name in head_names
                    }
                    total_loss = sum(head_losses.values())

                    # Convert everything first so a failure cannot leave the
                    # per-head histories with different lengths.
                    values = {'total': total_loss.item()}
                    values.update({name: loss.item() for name, loss in head_losses.items()})
                    for name in loss_names:
                        history[name].append(values[name])

                    _tqdm.set_postfix(
                        loss="{:.3f}, bar={:.3f}, pos={:.3f}, token={:.3f}, dur={:.3f}, phrase={:.3f}".format(
                            *(values[name] for name in loss_names)))
                except Exception as e:
                    # Skip only the offending batch; print a compact summary
                    # instead of dumping every tensor in it.
                    n_skipped += 1
                    if isinstance(data, dict):
                        summary = {k: tuple(v.shape) if torch.is_tensor(v) else type(v).__name__
                                   for k, v in data.items()}
                    else:
                        summary = type(data).__name__
                    print(summary)
                    print("Bad Data Item!")
                    print(e)
                finally:
                    _tqdm.update(1)

    if n_skipped:
        print(f"Skipped {n_skipped} bad validation batch(es).")

    return tuple(np.mean(history[name]) for name in loss_names)

In [23]:
def _load_pretrained_weights(model, pre_trained_path, map_location):
    """Warm-start ``model`` from a checkpoint, keeping only shape-compatible tensors.

    Tensors are matched by parameter name, not by position. A checkpoint
    tensor is used only when its name exists in the model and its shape
    agrees. Every other entry keeps the model's current value.

    Args:
        model: ``torch.nn.Module`` updated in place.
        pre_trained_path: path to a saved ``state_dict`` (presumably one written
            by ``torch.save(model.state_dict(), ...)`` as in ``train_l2m``).
        map_location: device the checkpoint tensors are loaded onto.

    Returns:
        ``(n_loaded, n_kept)``: number of tensors taken from the checkpoint and
        number that kept their current value.
    """
    current_state = model.state_dict()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint_state = torch.load(pre_trained_path, map_location=map_location)

    merged_state = {}
    n_loaded = 0
    for name, current_tensor in current_state.items():
        candidate = checkpoint_state.get(name)
        if candidate is not None and candidate.size() == current_tensor.size():
            merged_state[name] = candidate
            n_loaded += 1
        else:
            merged_state[name] = current_tensor
    model.load_state_dict(merged_state, strict=False)
    return n_loaded, len(current_state) - n_loaded


def train_l2m():
    """Train the Bart model on ``L2MDataset`` with early stopping.

    Steps:
        1. Seed the RNGs and load the event/word vocabulary.
        2. Build the train and valid loaders.
        3. Build the model, optionally warm-started from ``hparams['pretrain']``.
        4. Alternate ``train`` and ``valid`` epochs until early stopping fires
           or ``hparams['total_epoch']`` epochs have run.

    The best model by validation loss is saved to
    ``<checkpoint_dir>/checkpoint_<time>_lr_<lr>/best.pt``.

    Sets the module-level ``device`` that ``train`` and ``valid`` read.

    NOTE(review): the original comment called this "melody to lyric"
    generation. The function name and the src/tgt batch fields suggest
    lyric-to-melody instead; confirm the direction.
    """
    gc.collect()
    torch.cuda.empty_cache()

    global device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    set_seed()
    # NOTE: pickle.load can execute arbitrary code; only use a trusted vocabulary file.
    with open(f"{binary_dir}/music_dict.pkl", 'rb') as dict_file:
        event2word_dict, word2event_dict = pickle.load(dict_file)

    cur_time = get_time()

    # ---- data ----
    train_dataset = L2MDataset('train', event2word_dict, hparams, shuffle=True)
    valid_dataset = L2MDataset('valid', event2word_dict, hparams, shuffle=False)
    train_loader = build_dataloader(dataset=train_dataset, shuffle=True, batch_size=hparams['batch_size'], endless=False)
    val_loader = build_dataloader(dataset=valid_dataset, shuffle=False, batch_size=hparams['batch_size'], endless=False)
    print(len(train_loader))

    # ---- model ----
    model = Bart(event2word_dict=event2word_dict,
                 word2event_dict=word2event_dict,
                 model_pth='',
                 hidden_size=hparams['hidden_size'],
                 num_layers=hparams['n_layers'],
                 num_heads=hparams['n_head'],
                 dropout=hparams['drop_prob'],).to(device)

    pre_trained_path = hparams['pretrain']
    if pre_trained_path != '':
        n_loaded, n_kept = _load_pretrained_weights(model, pre_trained_path, device)
        print(">>> Load pretrained model successfully")
        print(f">>> {n_loaded} tensors loaded from checkpoint, {n_kept} kept from initialisation")

    # ---- optimisation ----
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=hparams['lr'],
        betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
        weight_decay=hparams['weight_decay'])

    total_epoch = hparams['total_epoch']
    # The linear schedule decays the LR to 0 at num_training_steps. The old
    # value of -1 made the post-warmup factor max(0, -1 - step) == 0, which
    # froze training after warmup. Pass the real number of optimizer steps
    # instead (the scheduler is stepped once per training batch).
    num_training_steps = len(train_loader) * total_epoch
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=hparams['warmup'], num_training_steps=num_training_steps
    )

    # ---- checkpointing / early stopping ----
    lr = hparams['lr']  # used in the checkpoint directory name
    checkpointpath = f"{hparams['checkpoint_dir']}/checkpoint_{cur_time}_lr_{lr}"
    os.makedirs(checkpointpath, exist_ok=True)
    early_stopping = EarlyStopping(patience=hparams['patience'], verbose=True,
                                   path=f"{checkpointpath}/early_stopping_checkpoint.pt")

    # ---- train & validate ----
    min_valid_running_loss = float('inf')
    with tqdm(total=total_epoch) as _tqdm:
        for epoch in range(total_epoch):
            train_running_loss, *_ = train(train_loader, model, optimizer, scheduler, epoch, total_epoch)
            valid_running_loss, *_ = valid(val_loader, model, epoch, total_epoch)

            early_stopping(valid_running_loss, model, epoch)
            if early_stopping.early_stop:
                print("Validation Loss convergence， Train over")
                break

            # Keep the best checkpoint by validation loss.
            if valid_running_loss < min_valid_running_loss:
                min_valid_running_loss = valid_running_loss
                torch.save(model.state_dict(), f"{checkpointpath}/best.pt")
            print(f"Training Running Loss = {train_running_loss}, "
                  f"Validation Running Loss = {valid_running_loss}, "
                  f"Best Validation Running Loss = {min_valid_running_loss}")
            _tqdm.update(1)

In [24]:
# Launch training: runs until early stopping triggers or hparams['total_epoch'] epochs finish.
train_l2m()

cuda:0
1545


training epoch: 1/1000:   0%|               | 6/1545 [00:00<02:03, 12.45it/s, loss=23.398, bar=5.553, pos=4.769, token=4.894, dur=5.664, phrase=2.518]

tensor([[ 1, 64, 66, 61, 64, 66, 64, 66, 61, 64, 69, 69, 69, 68, 66, 61, 64, 66,
         61, 64, 66, 64, 66, 61, 64, 59, 61, 63, 64, 66, 71, 68, 64, 64, 66, 61,
         64, 66, 64, 66, 61, 64, 64, 66, 66, 68, 68, 69, 66, 68, 69, 70, 71, 64,
         73, 71, 69, 64, 73, 73, 69, 64, 67, 73, 71, 69, 66, 71, 69, 66, 65, 71,
         69, 65, 64, 73, 73, 71, 69, 66, 74, 73, 64, 66,  2,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 52, 54, 56, 51, 52, 54, 56, 53, 54, 56, 57, 59, 56, 59, 54, 56, 57,
         59, 57, 59, 61, 63, 64, 66, 68, 68, 68, 64, 61, 64, 64, 66, 66, 63, 59,
         63, 64, 64, 61, 57, 61, 61, 54, 56, 57, 59, 54, 52, 54

training epoch: 1/1000:   1%|▏             | 14/1545 [00:00<01:08, 22.23it/s, loss=23.475, bar=5.504, pos=4.768, token=4.998, dur=5.579, phrase=2.626]

tensor([[ 1, 72, 72, 72, 72, 64, 66, 64, 61, 73, 73, 73, 73, 73, 73, 64, 66, 64,
         61, 71, 68, 68, 68, 68, 64, 68, 64, 68, 64, 68, 64, 68, 61, 68, 66, 64,
         64, 66, 64, 66, 64, 63, 61, 68, 68, 68, 68, 64, 66, 68, 66, 64, 68, 64,
         68, 61, 68, 66, 64, 64, 66, 64, 63, 61, 60, 60, 63, 61, 68, 67, 66, 66,
         64, 63, 61,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 66, 68, 64, 61, 64, 66, 66, 66, 68, 64, 61, 64, 68, 68, 66, 68, 66,
         64, 64, 71, 71, 59, 68, 66, 64, 59, 68, 66, 66, 64, 66, 71, 73, 71, 68,
         66, 68, 71, 73, 71, 68, 64, 66, 71, 73, 71, 66, 63, 64, 63, 64, 66, 68,
         66, 68, 69, 71, 71, 71, 73, 66, 68, 64, 71, 71, 71, 66, 68, 64, 68, 68,
         68, 66, 64, 66, 66, 66, 68, 69, 68

training epoch: 1/1000:   1%|▏             | 22/1545 [00:01<00:51, 29.40it/s, loss=23.227, bar=5.481, pos=4.635, token=4.945, dur=5.700, phrase=2.466]

tensor([[ 1, 81, 81, 81, 81, 81, 81, 80, 79, 79, 79, 79, 81, 83, 81, 81, 78, 83,
         83, 83, 83, 81, 83, 80, 76, 76, 78, 81, 78, 83, 78, 81, 78, 85, 83, 81,
         83, 85, 81, 78, 78, 81, 78, 79, 80, 82, 82, 82, 82, 82, 82, 85, 87, 82,
         85, 87, 87, 83, 88, 87, 85, 82, 82, 82, 82, 82, 85, 87, 82, 85, 87, 87,
         83, 88, 87, 85, 81, 81, 81, 81, 81, 81, 80, 79, 79, 79, 79, 81, 83, 81,
         81, 78, 78, 83, 83, 83, 83, 81, 83, 80, 76, 76, 78, 81, 78, 83, 78, 81,
         78, 85, 83, 81, 83, 85, 81, 78, 78, 81, 78, 79,  2],
        [ 1, 78, 78, 78, 78, 78, 78, 78, 78, 73, 75, 76, 78, 80, 87, 85, 80, 78,
         76, 78, 80, 76, 73, 76, 76, 75, 76, 83, 83, 80, 68, 73, 75, 76, 78, 80,
         87, 85, 79, 78, 76, 78, 80, 76, 73, 68, 73, 76, 78, 79, 78, 76, 73, 81,
         83, 85, 86, 85, 83, 81, 83, 85, 76, 73, 76, 76, 81, 83, 85, 86, 85, 83,
         81, 88, 88, 85, 85, 82, 84, 85, 87, 88, 87, 87, 85, 87, 75, 75, 75, 72,
         75, 75, 78, 75, 79, 75, 80, 82, 78, 80

training epoch: 1/1000:   2%|▎             | 30/1545 [00:01<00:45, 33.59it/s, loss=23.121, bar=5.527, pos=4.663, token=5.008, dur=5.609, phrase=2.314]

tensor([[ 1, 61, 68, 66, 68, 64, 63, 61, 68, 68, 61, 64, 66, 68, 66, 68, 68, 61,
         64, 64, 66, 68, 68, 68, 61, 64, 63, 61, 59, 61, 59, 56, 59, 61, 64, 64,
         63, 61, 63, 61, 59, 56, 68, 68, 61, 64, 66, 68, 66, 68, 68, 68, 61, 64,
         66, 68, 68, 68, 61, 64, 63, 63, 63, 61, 59, 61, 59, 56, 59, 61, 61, 61,
         68, 66, 68, 64, 63, 61, 68, 68, 61, 64, 66, 68, 66, 68, 68, 61, 64, 64,
         66, 68, 68, 68, 61, 64, 63, 63, 63, 61, 59, 61, 59, 56, 59, 61, 64, 64,
         63, 61, 63, 61, 59, 56, 68, 68, 61, 64, 66, 68, 66, 68, 68, 68, 61, 64,
         66, 68, 68, 61, 64, 63, 63, 63, 61, 59, 61, 59, 56, 59, 61, 68, 68, 68,
         61, 64, 66, 68, 68, 68, 61, 64, 63, 63, 63, 61, 59, 61, 59, 56, 59, 61,
          2],
        [ 1, 54, 54, 54, 57, 57, 57, 59, 54, 54, 54, 54, 57, 57, 57, 59, 54, 57,
         57, 57, 61, 59, 57, 59, 54, 54, 54, 57, 57, 59, 54, 54, 54, 57, 57, 59,
         54, 54, 54, 54, 57, 57, 57, 59, 54, 64, 59, 64, 59, 66, 63, 63, 64, 59,
         64, 5

training epoch: 1/1000:   2%|▎             | 38/1545 [00:01<00:43, 34.43it/s, loss=23.164, bar=5.596, pos=4.758, token=5.105, dur=5.519, phrase=2.186]

tensor([[ 1, 80, 78, 80, 80, 80, 78, 80, 80, 78, 80, 80, 78, 80, 80, 80, 78, 80,
         80, 78, 80, 73, 80, 78, 76, 78, 80, 78, 76, 78, 80, 76, 75, 73, 73, 80,
         78, 76, 78, 80, 78, 76, 78, 80, 76, 75, 73, 76, 76, 76, 76, 76, 76, 76,
         76, 76, 75, 73, 73, 76, 76, 76, 75, 73, 76, 75, 73, 75, 76, 73, 75, 76,
         73, 75, 76, 76, 76, 75, 76, 75, 73, 76, 75, 73, 76, 75, 73, 71, 73, 73,
         73, 75, 76, 76, 76, 75, 73, 76, 75, 73, 75, 76, 75, 73, 75, 76, 75, 73,
         75, 76, 75, 73, 75, 76, 75, 73, 75, 76, 75, 73, 76, 75, 73, 76, 75, 73,
         76, 75, 73, 75, 76, 75, 73, 71, 73, 73, 75, 76, 75, 73, 76, 78, 76, 75,
         80, 81, 80, 81, 80, 71, 71, 73, 73, 73, 73, 71, 73, 73, 73, 73, 71, 73,
         73, 73, 73, 73, 73, 85, 85, 71, 71, 73, 73, 73, 73, 71, 73, 73, 73, 73,
         71, 73, 73, 73, 73, 73, 73, 85, 85, 73, 80, 78, 76, 78, 80, 78, 76, 78,
         80, 76, 75, 73, 71, 73, 73, 73, 73, 80, 78, 76, 78, 80, 78, 76, 78, 80,
         76, 75, 73, 71, 73,

training epoch: 1/1000:   2%|▎             | 38/1545 [00:01<01:04, 23.41it/s, loss=23.164, bar=5.596, pos=4.758, token=5.105, dur=5.519, phrase=2.186]
  0%|                                                  | 0/1000 [00:01<?, ?it/s]


KeyboardInterrupt: 