In [9]:
import os
import pickle
import random
import subprocess
import torch.cuda
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from utils.earlystopping.protocols import EarlyStopping
from test_dataloder import *
import datetime
from utils.get_time import get_time
import gc
from tqdm import tqdm
from utils.warmup import *
import torch.nn.functional as F
from transformermodel import *
from transformers import get_linear_schedule_with_warmup

In [10]:
src_keys = ['strength', 'length', 'phrase']
tgt_keys = ['bar', 'pos', 'token', 'dur', 'phrase']

binary_dir = '/home/qihao/CS6207/binary'
words_dir = '/home/qihao/CS6207/binary/words'
hparams = {
    'batch_size': 2,
    'word_data_dir': '/home/qihao/CS6207/binary/words',
    'sentence_maxlen': 512,
    'hidden_size': 256,
    'n_layers': 6,
    'n_head': 8,
    'pretrain': '',
    'lr': 5.0e-5,
    'optimizer_adam_beta1': 0.9,
    'optimizer_adam_beta2': 0.98,
    'weight_decay': 0.001,
    'patience': 5,
    'warmup': 2500,
    'lr': 5.0e-5,
    'checkpoint_dir': '/home/qihao/CS6207/checkpoints',
    'drop_prob': 0.2,
    'total_epoch': 1000,
    'infer_batch_size': 1,
    'temperature': 1.3,
    'topk': 5,
    'prompt_step': 1,
    'infer_max_step': 1024,
    'output_dir': "/home/qihao/CS6207/output_melody",
    'num_heads': 4,
    'enc_layers': 4, 
    'dec_layers': 4, 
    'enc_ffn_kernel_size': 1,
    'dec_ffn_kernel_size': 1,
}

In [11]:
def set_seed(seed=1234):  # seed setting
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cuDNN在使用deterministic模式时（下面两行），可能会造成性能下降（取决于model）
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [12]:
def xe_loss(outputs, targets):
    outputs = outputs.transpose(1, 2)
    return F.cross_entropy(outputs, targets, ignore_index=0, reduction='mean')

In [13]:
def train(train_loader, model, optimizer, scheduler, epoch, total_epoch):
    # define the format of tqdm
    with tqdm(total=len(train_loader), ncols=150, position=0, leave=True) as _tqdm:  # 总长度是data的长度
        _tqdm.set_description('training epoch: {}/{}'.format(epoch + 1, total_epoch))  # 设置前缀更新信息

        # Model Train
        model.train()
        running_loss = 0.0
        train_loss = []
        train_bar_loss = []
        train_pos_loss = []
        train_token_loss = []
        train_dur_loss = []
        train_phrase_loss = []

        for idx, data in enumerate(train_loader):
            # prompt_index = list(data[f'tgt_word'].numpy()).index(50268)
            enc_inputs = {k: data[f'src_{k}'].to(device) for k in src_keys}
            dec_inputs = {k: data[f'tgt_{k}'].to(device) for k in tgt_keys}
            
            # zero the parameter gradients
            optimizer.zero_grad()
            
            bar_out, pos_out, token_out, dur_out, phrase_out = model(enc_inputs, dec_inputs)
            
            # print(bar_out, bar_out.logit())
            
            bar_out = bar_out #.logit()
            tgt_bar = (data['tgt_bar'].to(device))[:, 1:]
            bar_loss = xe_loss(bar_out[:, :-1], tgt_bar)
            
            pos_out = pos_out #.logit()
            tgt_pos = (data['tgt_pos'].to(device))[:, 1:]
            pos_loss = xe_loss(pos_out[:, :-1], tgt_pos)
            
            token_out = token_out #.logit()
            tgt_token = (data['tgt_token'].to(device))[:, 1:]
            token_loss = xe_loss(token_out[:, :-1], tgt_token)
            
            dur_out = dur_out #.logit()
            tgt_dur = (data['tgt_dur'].to(device))[:, 1:]
            dur_loss = xe_loss(dur_out[:, :-1], tgt_dur)
            
            phrase_out = phrase_out #.logit()
            tgt_phrase = (data['tgt_phrase'].to(device))[:, 1:]
            phrase_loss = xe_loss(phrase_out[:, :-1], tgt_phrase)
            

            # 3) total loss
            total_loss = bar_loss + pos_loss + token_loss + dur_loss + phrase_loss
            total_loss.backward()
            optimizer.step()
            scheduler.step()
            train_loss.append(total_loss.item())
            running_loss += total_loss.item()
            
            train_bar_loss.append(bar_loss.item())
            train_pos_loss.append(pos_loss.item())
            train_token_loss.append(token_loss.item())
            train_dur_loss.append(dur_loss.item())
            train_phrase_loss.append(phrase_loss.item())

            _tqdm.set_postfix(
                loss="{:.3f}, bar={:.3f}, pos={:.3f}, token={:.3f}, dur={:.3f}, phrase={:.3f}".format(total_loss,
                                                                                                      bar_loss, 
                                                                                                      pos_loss,
                                                                                                      token_loss,
                                                                                                      dur_loss,
                                                                                                      phrase_loss))
            
            _tqdm.update(2)

    train_loss_avg = np.mean(train_loss)
    train_bar_loss_avg = np.mean(train_bar_loss)
    train_pos_loss_avg = np.mean(train_pos_loss)
    train_token_loss_avg = np.mean(train_token_loss)
    train_dur_loss_avg = np.mean(train_dur_loss)
    train_phrase_loss_avg = np.mean(train_phrase_loss)
    
    return train_loss_avg, train_bar_loss_avg, train_pos_loss_avg, train_token_loss_avg, train_dur_loss_avg, train_phrase_loss_avg

In [14]:
def valid(valid_loader, model, epoch, total_epoch):
    # define the format of tqdm
    with tqdm(total=len(valid_loader), ncols=150) as _tqdm:  # 总长度是data的长度
        _tqdm.set_description('validation epoch: {}/{}'.format(epoch + 1, total_epoch))  # 设置前缀更新信息

        model.eval()  # switch to valid mode
        running_loss = 0.0
        val_loss = []
        val_bar_loss = []
        val_pos_loss = []
        val_token_loss = []
        val_dur_loss = []
        val_phrase_loss = []

        with torch.no_grad():
            for idx, data in enumerate((valid_loader)):
                try:
                    enc_inputs = {k: data[f'src_{k}'].to(device) for k in src_keys}
                    dec_inputs = {k: data[f'tgt_{k}'].to(device) for k in tgt_keys}

                    bar_out, pos_out, token_out, dur_out, phrase_out = model(enc_inputs, dec_inputs)

                    bar_out = bar_out #.logit()
                    tgt_bar = (data['tgt_bar'].to(device))[:, 1:]
                    bar_loss = xe_loss(bar_out[:, :-1], tgt_bar)

                    pos_out = pos_out #.logit()
                    tgt_pos = (data['tgt_pos'].to(device))[:, 1:]
                    pos_loss = xe_loss(pos_out[:, :-1], tgt_pos)

                    token_out = token_out #.logit()
                    tgt_token = (data['tgt_token'].to(device))[:, 1:]
                    token_loss = xe_loss(token_out[:, :-1], tgt_token)

                    dur_out = dur_out #.logit()
                    tgt_dur = (data['tgt_dur'].to(device))[:, 1:]
                    dur_loss = xe_loss(dur_out[:, :-1], tgt_dur)

                    phrase_out = phrase_out# .logit()
                    tgt_phrase = (data['tgt_phrase'].to(device))[:, 1:]
                    phrase_loss = xe_loss(phrase_out[:, :-1], tgt_phrase)


                    # 3) total loss
                    total_loss = bar_loss + pos_loss + token_loss + dur_loss + phrase_loss
                    val_loss.append(total_loss.item())
                    running_loss += total_loss.item()

                    val_bar_loss.append(bar_loss.item())
                    val_pos_loss.append(pos_loss.item())
                    val_token_loss.append(token_loss.item())
                    val_dur_loss.append(dur_loss.item())
                    val_phrase_loss.append(phrase_loss.item())

                    _tqdm.set_postfix(
                        loss="{:.3f}, bar={:.3f}, pos={:.3f}, token={:.3f}, dur={:.3f}, phrase={:.3f}".format(total_loss,
                                                                                                              bar_loss, 
                                                                                                              pos_loss,
                                                                                                              token_loss,
                                                                                                              dur_loss,
                                                                                                              phrase_loss))

                    _tqdm.update(2)
                    
                except Exception as e:
                    print(data)
                    print("Bad Data Item!")
                    print(e)
                    break
            
    val_loss_avg = np.mean(val_loss)
    val_bar_loss_avg = np.mean(val_bar_loss)
    val_pos_loss_avg = np.mean(val_pos_loss)
    val_token_loss_avg = np.mean(val_token_loss)
    val_dur_loss_avg = np.mean(val_dur_loss)
    val_phrase_loss_avg = np.mean(val_phrase_loss)

    return val_loss_avg, val_bar_loss_avg, val_pos_loss_avg, val_token_loss_avg, val_dur_loss_avg, val_phrase_loss_avg

In [15]:
def train_l2m():
    ## train melody to lyric generation
    gc.collect()
    torch.cuda.empty_cache()
    
    global device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    # args
    set_seed()
    # set_hparams()
    event2word_dict, word2event_dict = pickle.load(open(f"{binary_dir}/music_dict.pkl", 'rb'))

    # tensorboard logger
    cur_time = get_time()
    # tensorboard_dir = hparams['tensorboard']
    # train_log_dir = f'{tensorboard_dir}/{cur_time}/train'
    # valid_log_dir = f'{tensorboard_dir}/{cur_time}/valid'
    # train_writer = SummaryWriter(log_dir=train_log_dir)
    # valid_writer = SummaryWriter(log_dir=valid_log_dir)

    # ------------
    # train
    # ------------
    # load data
    train_dataset = L2MDataset('train', event2word_dict, hparams, shuffle=True)
    valid_dataset = L2MDataset('valid', event2word_dict, hparams, shuffle=False)

    train_loader = build_dataloader(dataset=train_dataset, shuffle=True, batch_size=hparams['batch_size'], endless=False)
    val_loader = build_dataloader(dataset=valid_dataset, shuffle=False, batch_size=hparams['batch_size'], endless=False)
    
    print(len(train_loader))
    
    # print(f"foundation model pth: {hparams['custom_model_dir']}")
    
    def tensor_check_fn(key, param, input_param, error_msgs):
        if param.shape != input_param.shape:
            return False
        return True
    
    model = MusicTransformer(event2word_dict=event2word_dict, 
                             word2event_dict=word2event_dict, 
                             hidden_size=hparams['hidden_size'], 
                             num_heads=hparams['num_heads'],
                             enc_layers=hparams['enc_layers'], 
                             dec_layers=hparams['dec_layers'], 
                             dropout=hparams['drop_prob'], 
                             enc_ffn_kernel_size=hparams['enc_ffn_kernel_size'],
                             dec_ffn_kernel_size=hparams['dec_ffn_kernel_size'],
                            ).to(device)
    
    pre_trained_path = hparams['pretrain']
    if pre_trained_path != '':
        current_model_dict = model.state_dict()
        loaded_state_dict = torch.load(pre_trained_path)
        new_state_dict={k:v if v.size()==current_model_dict[k].size() else current_model_dict[k] for k,v in zip(current_model_dict.keys(), loaded_state_dict.values())}
        # model.load_state_dict(new_state_dict, strict=False)
        # model.load_state_dict(torch.load(pre_trained_path), strict=False, tensor_check_fn=tensor_check_fn)
        model.load_state_dict(new_state_dict, strict=False)
        print(">>> Load pretrained model successfully")
        
    ## warm up
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=hparams['lr'],
        betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
        weight_decay=hparams['weight_decay'])

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=hparams['warmup'], num_training_steps=-1
    )

    """
    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)
    model.to(device)
    """

    # training conditions (for naming the ckpt)
    lr = hparams['lr']

    # early stop: initialize the early_stopping object
    # checkpointpath = f"{hparams['checkpoint_dir']}/Cond_{cond}_GPT2_{cur_time}_lr{lr}"
    checkpointpath = f"{hparams['checkpoint_dir']}/checkpoint_{cur_time}_lr_{lr}"
    if not os.path.exists(checkpointpath):
        os.mkdir(checkpointpath)
    early_stopping = EarlyStopping(patience=hparams['patience'], verbose=True,
                                   path=f"{checkpointpath}/early_stopping_checkpoint.pt")
    

    # -------- Train & Validation -------- #
    min_valid_running_loss = 1000000  # inf
    total_epoch = hparams['total_epoch']
    with tqdm(total=total_epoch) as _tqdm:
        for epoch in range(total_epoch):
            # Train
            train_running_loss, _, _, _, _, _ = train(train_loader, model, optimizer, scheduler, epoch, total_epoch)
            # train_writer.add_scalars("train_epoch_loss", {"running": train_running_loss, 'word': train_word_loss}, epoch)

            # validation  
            valid_running_loss, _, _, _, _, _ = valid(val_loader, model, epoch, total_epoch)
            # valid_writer.add_scalars("valid_epoch_loss", {"running": valid_running_loss, 'word': valid_word_loss}, epoch)

            # early stopping Check
            early_stopping(valid_running_loss, model, epoch)
            if early_stopping.early_stop == True:
                print("Validation Loss convergence， Train over")
                break

            # save the best checkpoint
            if valid_running_loss < min_valid_running_loss:
                min_valid_running_loss = valid_running_loss
                torch.save(model.state_dict(), f"{checkpointpath}/best.pt")
            print(f"Training Runinng Loss = {train_running_loss}, Validation Running Loss = {min_valid_running_loss}")  
            _tqdm.update(2)

In [16]:
train_l2m()

cuda:0
1545


training epoch: 1/1000: : 3090it [00:47, 64.75it/s, loss=9.764, bar=2.347, pos=2.423, token=2.853, dur=1.799, phrase=0.342]                           

  0%|                                                                                                                         | 0/193 [00:00<?, ?it/s][A
validation epoch: 1/1000:   0%|                                                                                               | 0/193 [00:00<?, ?it/s][A
validation epoch: 1/1000:   0%|                      | 0/193 [00:00<?, ?it/s, loss=11.550, bar=3.663, pos=2.646, token=3.368, dur=1.538, phrase=0.335][A
validation epoch: 1/1000:   1%|▏             | 2/193 [00:00<00:38,  5.00it/s, loss=11.550, bar=3.663, pos=2.646, token=3.368, dur=1.538, phrase=0.335][A
validation epoch: 1/1000:   1%|▏              | 2/193 [00:00<00:38,  5.00it/s, loss=8.803, bar=2.452, pos=1.870, token=2.690, dur=1.375, phrase=0.416][A
validation epoch: 1/1000:   2%|▎              | 4/193 [00:00<00:37,  5.00it/s,

Validation loss decreased (inf --> 10.224549).  Saving model ...
Training Runinng Loss = 13.8567897037395, Validation Running Loss = 10.224548517731188


training epoch: 2/1000: : 3090it [00:43, 71.26it/s, loss=7.662, bar=1.454, pos=1.701, token=2.408, dur=1.795, phrase=0.304]                           

  0%|                                                                                                                         | 0/193 [00:00<?, ?it/s][A
validation epoch: 2/1000:   0%|                                                                                               | 0/193 [00:00<?, ?it/s][A
validation epoch: 2/1000:   0%|                       | 0/193 [00:00<?, ?it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 2/1000:   1%|▏              | 2/193 [00:00<00:33,  5.69it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 2/1000:   1%|▏              | 2/193 [00:00<00:33,  5.69it/s, loss=6.542, bar=1.460, pos=1.020, token=2.509, dur=1.169, phrase=0.384][A
validation epoch: 2/1000:   2%|▎              | 4/193 [00:00<00:33,  5.69it/s,

Validation loss decreased (10.224549 --> 7.776169).  Saving model ...
Training Runinng Loss = 8.789494558599774, Validation Running Loss = 7.776168944304471


training epoch: 3/1000: : 3090it [00:44, 68.69it/s, loss=7.487, bar=1.399, pos=1.619, token=2.395, dur=1.786, phrase=0.289]                           

  0%|                                                                                                                         | 0/193 [00:00<?, ?it/s][A
validation epoch: 3/1000:   0%|                                                                                               | 0/193 [00:00<?, ?it/s][A
validation epoch: 3/1000:   0%|                       | 0/193 [00:00<?, ?it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 3/1000:   1%|▏              | 2/193 [00:00<00:33,  5.63it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 3/1000:   1%|▏              | 2/193 [00:00<00:33,  5.63it/s, loss=6.542, bar=1.460, pos=1.020, token=2.509, dur=1.169, phrase=0.384][A
validation epoch: 3/1000:   2%|▎              | 4/193 [00:00<00:33,  5.63it/s,

Validation loss decreased (7.776169 --> 7.776169).  Saving model ...
Training Runinng Loss = 8.11411827112093, Validation Running Loss = 7.776168944304471


training epoch: 4/1000: : 3090it [00:45, 68.10it/s, loss=7.586, bar=1.442, pos=1.655, token=2.376, dur=1.819, phrase=0.296]                           

  0%|                                                                                                                         | 0/193 [00:00<?, ?it/s][A
validation epoch: 4/1000:   0%|                                                                                               | 0/193 [00:00<?, ?it/s][A
validation epoch: 4/1000:   0%|                       | 0/193 [00:00<?, ?it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 4/1000:   1%|▏              | 2/193 [00:00<00:31,  6.03it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 4/1000:   1%|▏              | 2/193 [00:00<00:31,  6.03it/s, loss=6.542, bar=1.460, pos=1.020, token=2.509, dur=1.169, phrase=0.384][A
validation epoch: 4/1000:   2%|▎              | 4/193 [00:00<00:31,  6.03it/s,

Validation loss decreased (7.776169 --> 7.776169).  Saving model ...
Training Runinng Loss = 8.112381440999053, Validation Running Loss = 7.776168944304471


training epoch: 5/1000: : 3090it [00:43, 71.86it/s, loss=7.559, bar=1.427, pos=1.616, token=2.405, dur=1.805, phrase=0.305]                           

  0%|                                                                                                                         | 0/193 [00:00<?, ?it/s][A
validation epoch: 5/1000:   0%|                                                                                               | 0/193 [00:00<?, ?it/s][A
validation epoch: 5/1000:   0%|                       | 0/193 [00:00<?, ?it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 5/1000:   1%|▏              | 2/193 [00:00<00:27,  6.86it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 5/1000:   1%|▏              | 2/193 [00:00<00:27,  6.86it/s, loss=6.542, bar=1.460, pos=1.020, token=2.509, dur=1.169, phrase=0.384][A
validation epoch: 5/1000:   2%|▎              | 4/193 [00:00<00:27,  6.86it/s,

Validation loss decreased (7.776169 --> 7.776169).  Saving model ...
Training Runinng Loss = 8.110670210705607, Validation Running Loss = 7.776168944304471


training epoch: 6/1000: : 3090it [00:44, 69.56it/s, loss=7.495, bar=1.424, pos=1.615, token=2.398, dur=1.759, phrase=0.299]                           

  0%|                                                                                                                         | 0/193 [00:00<?, ?it/s][A
validation epoch: 6/1000:   0%|                                                                                               | 0/193 [00:00<?, ?it/s][A
validation epoch: 6/1000:   0%|                       | 0/193 [00:00<?, ?it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 6/1000:   1%|▏              | 2/193 [00:00<00:31,  6.02it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 6/1000:   1%|▏              | 2/193 [00:00<00:31,  6.02it/s, loss=6.542, bar=1.460, pos=1.020, token=2.509, dur=1.169, phrase=0.384][A
validation epoch: 6/1000:   2%|▎              | 4/193 [00:00<00:31,  6.02it/s,

Validation loss decreased (7.776169 --> 7.776169).  Saving model ...
Training Runinng Loss = 8.11290775039821, Validation Running Loss = 7.776168944304471


training epoch: 7/1000: : 3090it [00:47, 64.99it/s, loss=7.509, bar=1.471, pos=1.607, token=2.352, dur=1.792, phrase=0.286]                           

  0%|                                                                                                                         | 0/193 [00:00<?, ?it/s][A
validation epoch: 7/1000:   0%|                                                                                               | 0/193 [00:00<?, ?it/s][A
validation epoch: 7/1000:   0%|                       | 0/193 [00:00<?, ?it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 7/1000:   1%|▏              | 2/193 [00:00<00:30,  6.21it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 7/1000:   1%|▏              | 2/193 [00:00<00:30,  6.21it/s, loss=6.542, bar=1.460, pos=1.020, token=2.509, dur=1.169, phrase=0.384][A
validation epoch: 7/1000:   2%|▎              | 4/193 [00:00<00:30,  6.21it/s,

Validation loss decreased (7.776169 --> 7.776169).  Saving model ...
Training Runinng Loss = 8.112731315711555, Validation Running Loss = 7.776168944304471


training epoch: 8/1000: : 3090it [00:44, 69.75it/s, loss=7.536, bar=1.458, pos=1.656, token=2.350, dur=1.767, phrase=0.305]                           

  0%|                                                                                                                         | 0/193 [00:00<?, ?it/s][A
validation epoch: 8/1000:   0%|                                                                                               | 0/193 [00:00<?, ?it/s][A
validation epoch: 8/1000:   0%|                       | 0/193 [00:00<?, ?it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 8/1000:   1%|▏              | 2/193 [00:00<00:30,  6.32it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 8/1000:   1%|▏              | 2/193 [00:00<00:30,  6.32it/s, loss=6.542, bar=1.460, pos=1.020, token=2.509, dur=1.169, phrase=0.384][A
validation epoch: 8/1000:   2%|▎              | 4/193 [00:00<00:29,  6.32it/s,

Validation loss decreased (7.776169 --> 7.776169).  Saving model ...
Training Runinng Loss = 8.113987973901446, Validation Running Loss = 7.776168944304471


training epoch: 9/1000: : 3090it [00:46, 66.52it/s, loss=7.573, bar=1.439, pos=1.626, token=2.410, dur=1.803, phrase=0.295]                           

  0%|                                                                                                                         | 0/193 [00:00<?, ?it/s][A
validation epoch: 9/1000:   0%|                                                                                               | 0/193 [00:00<?, ?it/s][A
validation epoch: 9/1000:   0%|                       | 0/193 [00:00<?, ?it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 9/1000:   1%|▏              | 2/193 [00:00<00:26,  7.24it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 9/1000:   1%|▏              | 2/193 [00:00<00:26,  7.24it/s, loss=6.542, bar=1.460, pos=1.020, token=2.509, dur=1.169, phrase=0.384][A
validation epoch: 9/1000:   2%|▎              | 4/193 [00:00<00:26,  7.24it/s,

Validation loss decreased (7.776169 --> 7.776169).  Saving model ...
Training Runinng Loss = 8.113148598532074, Validation Running Loss = 7.776168944304471


training epoch: 10/1000: : 3090it [00:45, 67.55it/s, loss=7.598, bar=1.453, pos=1.650, token=2.406, dur=1.799, phrase=0.290]                          

  0%|                                                                                                                         | 0/193 [00:00<?, ?it/s][A
validation epoch: 10/1000:   0%|                                                                                              | 0/193 [00:00<?, ?it/s][A
validation epoch: 10/1000:   0%|                      | 0/193 [00:00<?, ?it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 10/1000:   1%|▏             | 2/193 [00:00<00:29,  6.42it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 10/1000:   1%|▏             | 2/193 [00:00<00:29,  6.42it/s, loss=6.542, bar=1.460, pos=1.020, token=2.509, dur=1.169, phrase=0.384][A
validation epoch: 10/1000:   2%|▎             | 4/193 [00:00<00:29,  6.42it/s,

Validation loss decreased (7.776169 --> 7.776169).  Saving model ...
Training Runinng Loss = 8.114177844902459, Validation Running Loss = 7.776168944304471


training epoch: 11/1000: : 3090it [00:42, 71.96it/s, loss=7.527, bar=1.451, pos=1.612, token=2.389, dur=1.762, phrase=0.312]                          

  0%|                                                                                                                         | 0/193 [00:00<?, ?it/s][A
validation epoch: 11/1000:   0%|                                                                                              | 0/193 [00:00<?, ?it/s][A
validation epoch: 11/1000:   0%|                      | 0/193 [00:00<?, ?it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 11/1000:   1%|▏             | 2/193 [00:00<00:30,  6.25it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 11/1000:   1%|▏             | 2/193 [00:00<00:30,  6.25it/s, loss=6.542, bar=1.460, pos=1.020, token=2.509, dur=1.169, phrase=0.384][A
validation epoch: 11/1000:   2%|▎             | 4/193 [00:00<00:30,  6.25it/s,

Validation loss decreased (7.776169 --> 7.776169).  Saving model ...
Training Runinng Loss = 8.113624964717136, Validation Running Loss = 7.776168944304471


training epoch: 12/1000: : 3090it [00:42, 72.41it/s, loss=7.621, bar=1.434, pos=1.625, token=2.460, dur=1.790, phrase=0.311]                          

  0%|                                                                                                                         | 0/193 [00:00<?, ?it/s][A
validation epoch: 12/1000:   0%|                                                                                              | 0/193 [00:00<?, ?it/s][A
validation epoch: 12/1000:   0%|                      | 0/193 [00:00<?, ?it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 12/1000:   1%|▏             | 2/193 [00:00<00:30,  6.26it/s, loss=9.238, bar=2.581, pos=1.936, token=3.006, dur=1.420, phrase=0.295][A
validation epoch: 12/1000:   1%|▏             | 2/193 [00:00<00:30,  6.26it/s, loss=6.542, bar=1.460, pos=1.020, token=2.509, dur=1.169, phrase=0.384][A
validation epoch: 12/1000:   2%|▎             | 4/193 [00:00<00:30,  6.26it/s,

Validation loss decreased (7.776169 --> 7.776169).  Saving model ...
Training Runinng Loss = 8.111280704547672, Validation Running Loss = 7.776168944304471


training epoch: 13/1000:  52%|██████▋      | 796/1545 [00:11<00:11, 66.81it/s, loss=7.729, bar=1.759, pos=1.422, token=2.515, dur=1.592, phrase=0.442]
  2%|▉                                      | 24/1000 [09:42<6:35:03, 24.29s/it]

KeyboardInterrupt

