In [1]:
import torch
from torcheval.metrics.text import Perplexity
import random
from tqdm import tqdm
from data_utils import compute_normalized_token_entropy
import random
import numpy as np
import os
from transformers import get_cosine_schedule_with_warmup
from data_utils import CSGridMLMDataset, CSGridMLM_collate_fn
from GridMLM_tokenizers import CSGridMLMTokenizer
from torch.utils.data import DataLoader
from models import DualGridMLMMelHarm
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from train_utils import train_with_curriculum, apply_masking

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = CSGridMLMTokenizer(
    fixed_length=80,
    quantization='4th',
    intertwine_bar_info=True,
    trim_start=False,
    use_pc_roll=True,
    use_full_range_melody=False
)

In [3]:
train_dir = '/media/maindisk/data/synthetic_CA_train'
val_dir = '/media/maindisk/data/synthetic_CA_test'
batchsize = 2

In [4]:
train_dataset = CSGridMLMDataset(train_dir, tokenizer, name_suffix='DE')
val_dataset = CSGridMLMDataset(val_dir, tokenizer, name_suffix='DE')

trainloader = DataLoader(train_dataset, batch_size=batchsize, shuffle=True, collate_fn=CSGridMLM_collate_fn)
valloader = DataLoader(val_dataset, batch_size=batchsize, shuffle=False, collate_fn=CSGridMLM_collate_fn)

Loading data file.
Loading data file.


In [5]:
device_name = 'cuda:0'
if device_name == 'cpu':
    device = torch.device('cpu')
else:
    if torch.cuda.is_available():
        device = torch.device(device_name)
    else:
        print('Selected device not available: ' + device_name)

In [6]:
loss_fn=CrossEntropyLoss(ignore_index=-100)
lr = 1e-5

In [7]:
total_stages = 10

model = DualGridMLMMelHarm(
    chord_vocab_size=len(tokenizer.vocab),
    d_model=512,
    nhead=8,
    num_layers_mel=8,
    num_layers_harm=8,
    melody_length=80,
    harmony_length=80,
    max_stages=total_stages,
    pianoroll_dim=tokenizer.pianoroll_dim,
    device=device
)
model.to(device)
optimizer = AdamW(model.parameters(), lr=lr)

In [8]:
batch = next(iter(trainloader))
model.train()

DualGridMLMMelHarm(
  (melody_proj): Linear(in_features=13, out_features=512, bias=True)
  (harmony_embedding): Embedding(355, 512)
  (stage_embedding): Embedding(10, 64)
  (stage_proj): Linear(in_features=576, out_features=512, bias=True)
  (melody_encoder): SimpleTransformerStack(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.3, inplace=False)
        (dropout2): Dropout(p=0.3, inplace=False)
      )
    )
  )
  (harmony_encoder): HarmonyTransformerSta

In [9]:
melody_grid = batch["pianoroll"].to(device)           # (B, 256, 100)
harmony_gt = batch["input_ids"].to(device)         # (B, 256)

In [10]:
print(melody_grid.shape)
print(harmony_gt.shape)

torch.Size([2, 80, 13])
torch.Size([2, 80])


In [11]:
mask_token_id = tokenizer.mask_token_id
bar_token_id = tokenizer.bar_token_id
curriculum_type = 'random'

In [12]:
rets = apply_masking(
    harmony_gt,
    mask_token_id,
    total_stages=total_stages,
    curriculum_type=curriculum_type,
    bar_token_id=bar_token_id
)

visible_idx:  tensor([], device='cuda:0', dtype=torch.int64)
predict_idx:  tensor([12, 18, 49, 72, 21,  9, 24, 52], device='cuda:0')
visible_idx:  tensor([22,  0, 76, 31,  9, 71, 41, 77, 49, 66, 48,  8, 60, 47, 15, 29, 55, 28,
        26, 38, 72, 20, 19, 17,  6, 11, 27, 56, 59,  4, 53, 14, 50, 64, 42, 61,
        62, 51, 67, 21, 74, 36, 69, 30, 25, 75,  2, 35, 44, 58, 16, 12, 54, 65,
         5, 34, 45,  3, 57, 78, 63, 13, 68, 70], device='cuda:0')
predict_idx:  tensor([22,  0, 76, 31,  9, 71, 41, 77, 49, 66, 48,  8, 60, 47, 15, 29, 55, 28,
        26, 38, 72, 20, 19, 17,  6, 11, 27, 56, 59,  4, 53, 14, 50, 64, 42, 61,
        62, 51, 67, 21, 74, 36, 69, 30, 25, 75,  2, 35, 44, 58, 16, 12, 54, 65,
         5, 34, 45,  3, 57, 78, 63, 13, 68, 70, 73,  1, 18, 40, 23, 37, 46, 10],
       device='cuda:0')


In [13]:
harmony_input, harmony_target, stage_indices = rets[0], rets[1], rets[2]

In [14]:
# print(harmony_input[0,:])
# print(harmony_target[0,:])
# print(stage_indices)
for i in range(harmony_input.shape[1]):
    print(harmony_gt[0][i].item(), harmony_input[0,i].item(), harmony_target[0,i].item())

6 6 -100
210 5 -100
269 5 -100
210 5 -100
7 5 -100
6 6 -100
152 5 -100
269 5 -100
269 5 -100
7 5 7
6 6 -100
7 5 -100
329 5 329
7 5 -100
210 5 -100
6 6 -100
152 5 -100
7 5 -100
7 5 7
124 5 -100
6 6 -100
329 5 329
66 5 -100
210 5 -100
7 5 7
6 6 -100
269 5 -100
7 5 -100
269 5 -100
329 5 -100
6 6 -100
210 5 -100
210 5 -100
124 5 -100
7 5 -100
6 6 -100
152 5 -100
152 5 -100
66 5 -100
124 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 1
1 5 -100
1 5 -100
1 5 1
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 1
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100
1 5 -100


In [15]:
print(tokenizer.vocab)

{'<unk>': 0, '<pad>': 1, '<s>': 2, '</s>': 3, '<nc>': 4, '<mask>': 5, '<bar>': 6, 'C:maj': 7, 'C:min': 8, 'C:aug': 9, 'C:dim': 10, 'C:sus4': 11, 'C:sus2': 12, 'C:7': 13, 'C:maj7': 14, 'C:min7': 15, 'C:minmaj7': 16, 'C:maj6': 17, 'C:min6': 18, 'C:dim7': 19, 'C:hdim7': 20, 'C:maj9': 21, 'C:min9': 22, 'C:9': 23, 'C:min11': 24, 'C:11': 25, 'C:maj13': 26, 'C:min13': 27, 'C:13': 28, 'C:1': 29, 'C:5': 30, 'C': 31, 'C:7(b9)': 32, 'C:7(#9)': 33, 'C:7(#11)': 34, 'C:7(b13)': 35, 'C#:maj': 36, 'C#:min': 37, 'C#:aug': 38, 'C#:dim': 39, 'C#:sus4': 40, 'C#:sus2': 41, 'C#:7': 42, 'C#:maj7': 43, 'C#:min7': 44, 'C#:minmaj7': 45, 'C#:maj6': 46, 'C#:min6': 47, 'C#:dim7': 48, 'C#:hdim7': 49, 'C#:maj9': 50, 'C#:min9': 51, 'C#:9': 52, 'C#:min11': 53, 'C#:11': 54, 'C#:maj13': 55, 'C#:min13': 56, 'C#:13': 57, 'C#:1': 58, 'C#:5': 59, 'C#': 60, 'C#:7(b9)': 61, 'C#:7(#9)': 62, 'C#:7(#11)': 63, 'C#:7(b13)': 64, 'D:maj': 65, 'D:min': 66, 'D:aug': 67, 'D:dim': 68, 'D:sus4': 69, 'D:sus2': 70, 'D:7': 71, 'D:maj7': 72,