In [1]:
from data_utils import CSGridMLMDataset, CSGridMLM_collate_fn
from GridMLM_tokenizers import CSGridMLMTokenizer
import os
import numpy as np
from torch.utils.data import DataLoader
from models import DualGridMLMMelHarm, SEModular
import torch
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- experiment identity -------------------------------------------------
# `subfolder` doubles as a feature tag: later cells parse the grid length and
# tokenizer switches out of this string.
subfolder = 'Q4_L80_bar_PC'
curriculum_type = 'f2f'
exponent = 5  # power applied to the visible-token fraction of the masking schedule

# --- data locations ------------------------------------------------------
val_dir = '/media/maindisk/data/synthetic_CA_test'
train_dir = '/media/maindisk/data/synthetic_CA_train'

# --- optimisation settings -----------------------------------------------
batchsize = 8
lr = 0.0001
epochs = 200
device_name = 'cpu'

In [3]:
# Stage-based unmasking is only used outside the 'f2f' curriculum.
if curriculum_type == 'f2f':
    total_stages = None
else:
    total_stages = 10

# No conditioning vector when the dataset variant carries bar information.
if 'bar' in subfolder:
    condition_dim = None
else:
    condition_dim = 16

trainable_pos_emb = False

# The grid length is the number that follows '_L' in the subfolder tag
# (e.g. 'Q4_L80_bar_PC' -> 80).
grid_lenght = int(subfolder.split('_L', 1)[1].split('_', 1)[0])

# Every tokenizer switch below is read off the same subfolder tag.
tokenizer = CSGridMLMTokenizer(
    quantization='4th' if 'Q16' not in subfolder else '16th',
    fixed_length=grid_lenght,
    trim_start=False,
    intertwine_bar_info='bar' in subfolder,
    use_pc_roll='PC' in subfolder,
    use_full_range_melody='FR' in subfolder,
)

In [4]:
# Both splits share the tokenizer and the subfolder tag as name suffix.
train_dataset = CSGridMLMDataset(train_dir, tokenizer, name_suffix=subfolder)
val_dataset = CSGridMLMDataset(val_dir, tokenizer, name_suffix=subfolder)

# Only the training loader shuffles; validation keeps dataset order.
trainloader = DataLoader(
    dataset=train_dataset,
    collate_fn=CSGridMLM_collate_fn,
    batch_size=batchsize,
    shuffle=True,
)
valloader = DataLoader(
    dataset=val_dataset,
    collate_fn=CSGridMLM_collate_fn,
    batch_size=batchsize,
    shuffle=False,
)

Loading data file.
Loading data file.


In [5]:
# Resolve the torch device requested by `device_name`.
if device_name == 'cpu':
    device = torch.device('cpu')
elif torch.cuda.is_available():
    device = torch.device(device_name)
else:
    # This branch used to print the message only, leaving `device` undefined,
    # so the notebook failed later with a NameError when the model was built.
    # Fall back to the CPU explicitly so the following cells keep working.
    print('Selected device not available: ' + device_name)
    print('Falling back to cpu.')
    device = torch.device('cpu')
# end device selection

# Target entries equal to -100 are excluded from the loss.
loss_fn = CrossEntropyLoss(ignore_index=-100)

In [6]:
# Instantiate DualGridMLMMelHarm; the vocabulary size and piano-roll width
# come from the tokenizer, both sequence lengths from the grid length.
model = DualGridMLMMelHarm(
    chord_vocab_size=len(tokenizer.vocab),
    pianoroll_dim=tokenizer.pianoroll_dim,
    melody_length=grid_lenght,
    harmony_length=grid_lenght,
    d_model=512,
    nhead=8,
    num_layers_mel=8,
    num_layers_harm=8,
    device=device,
)
# Alternative model, kept for reference (uses condition_dim, total_stages
# and trainable_pos_emb from the settings cell):
# model = SEModular(
#     chord_vocab_size=len(tokenizer.vocab),
#     d_model=512,
#     nhead=8,
#     num_layers=8,
#     grid_length=grid_lenght,
#     pianoroll_dim=tokenizer.pianoroll_dim,
#     condition_dim=condition_dim,  # if not None, add a condition token of this dim at start
#     unmasking_stages=total_stages,  # if not None, use stage-based unmasking
#     trainable_pos_emb=trainable_pos_emb,
#     device=device,
# )
model.to(device)
optimizer = AdamW(params=model.parameters(), lr=lr)

In [7]:
# Pull a single batch for the smoke test below. The training loader is built
# with shuffle=True and no seed is set in this notebook, so the batch differs
# between runs.
batch = next(iter(trainloader))

In [8]:
# Ground-truth harmony ids of the sampled batch, moved to the selected device.
harmony_gt = batch["harmony_ids"].to(device)
# NOTE(review): this prints the whole tensor; a shape or a single row would
# keep the saved notebook output short.
print(harmony_gt)

tensor([[  6, 210, 269, 329, 152,   6, 124, 124, 269,   7,   6, 329,   7, 210,
         269,   6,  66, 124, 269,   7,   6,  66, 124, 210, 329,   6, 269, 329,
          66,   7,   6, 329, 210,  66,   7,   6, 152, 269, 210,   7,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1],
        [  6,  66, 124, 329, 124,   6, 210, 269, 124, 269,   6, 152, 329, 210,
          66,   6, 152, 329, 124, 210,   6,   7,  66, 329, 269,   6, 329, 269,
         329, 152,   6,   7, 210, 329, 329,   6, 152, 210,   7, 152,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1],
        [  6,  66, 329, 152, 210,   6, 124,  66,  66,   7,   6, 269, 152, 269,
         21

In [9]:
def full_to_partial_masking(
        harmony_tokens,
        mask_token_id,
        num_visible=0,
        bar_token_id=None
    ):
    """
    Build the partially masked model input and the matching denoising target.

    One random permutation of the L sequence positions is drawn and shared by
    every row of the batch. The first ``num_visible`` positions of that
    permutation are revealed in the input; all remaining positions are masked
    in the input and become prediction targets.

    Args:
        harmony_tokens (torch.Tensor): Tensor of shape (B, L) containing the
            ground-truth harmony token ids. It is not modified.
        mask_token_id (int): Token id written at hidden positions of the input.
        num_visible (int): Number of sequence positions to reveal. Values
            outside ``[0, L]`` are clamped to that range.
        bar_token_id (int, optional): If given, every position holding this id
            is always revealed in the input and always present in the target,
            regardless of the permutation.

    Returns:
        visible_harmony (torch.Tensor): Tensor of shape (B, L) holding the
            original token at revealed positions and ``mask_token_id`` elsewhere.
        denoising_target (torch.Tensor): Tensor of shape (B, L) holding the
            original token at positions to predict and -100 elsewhere.
    """
    # Unpacking keeps the requirement that the input is exactly 2-D.
    _, seq_len = harmony_tokens.shape

    # Slicing would read a negative count as "all but the last k" and reveal
    # almost every token, so clamp the request to the valid range first.
    num_visible = max(0, min(num_visible, seq_len))

    visible_harmony = torch.full_like(harmony_tokens, fill_value=mask_token_id)
    denoising_target = torch.full_like(harmony_tokens, fill_value=-100)  # -100 is ignored by CrossEntropyLoss

    if bar_token_id is not None:
        # Bar tokens are never hidden and are always part of the target.
        bar_mask = harmony_tokens == bar_token_id
        visible_harmony[bar_mask] = bar_token_id
        denoising_target[bar_mask] = bar_token_id

    # A single permutation of the positions, applied to all rows of the batch.
    perm = torch.randperm(seq_len, device=harmony_tokens.device)
    visible_idx = perm[:num_visible]
    predict_idx = perm[num_visible:]  # predict all remaining

    visible_harmony[:, visible_idx] = harmony_tokens[:, visible_idx]
    denoising_target[:, predict_idx] = harmony_tokens[:, predict_idx]

    return visible_harmony, denoising_target
# end full_to_partial_masking

In [10]:
step = 0
total_steps = 1000

# Training progress in (0, 1], raised to `exponent`. With exponent 5 the
# visible fraction stays close to zero for roughly the first half of the
# schedule.
percent_visible = pow(min(1.0, (step + 1) / total_steps), exponent)

L = harmony_gt.size(1)
# Cap at L-1 to ensure at least one token is predicted.
num_visible = min(int(L * percent_visible), L - 1)

harmony_input, harmony_target = full_to_partial_masking(
    harmony_tokens=harmony_gt,
    mask_token_id=tokenizer.mask_token_id,
    num_visible=num_visible,
    bar_token_id=tokenizer.bar_token_id,
)

# Neither stage indices nor a conditioning vector are used in this run.
stage_indices = conditioning_vec = None

In [11]:
# Compare the masked model input with the ground truth for the first
# sequence of the batch.
print(harmony_input[0])
print(harmony_gt[0])

tensor([6, 5, 5, 5, 5, 6, 5, 5, 5, 5, 6, 5, 5, 5, 5, 6, 5, 5, 5, 5, 6, 5, 5, 5,
        5, 6, 5, 5, 5, 5, 6, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5])
tensor([  6, 210, 269, 329, 152,   6, 124, 124, 269,   7,   6, 329,   7, 210,
        269,   6,  66, 124, 269,   7,   6,  66, 124, 210, 329,   6, 269, 329,
         66,   7,   6, 329, 210,  66,   7,   6, 152, 269, 210,   7,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1])


In [25]:
# Forward pass through the model built above.
# NOTE(review): the first positional argument is the ground-truth harmony
# (`harmony_gt`), passed in place of the commented-out melody grid. This looks
# like a debugging substitution; confirm before trusting the resulting loss,
# since the model may be able to read the targets from this input.
logits = model(
    # melody_grid.to(device),
    harmony_gt.to(device),
    harmony_input.to(device),
    conditioning_vec,
    stage_indices
)

In [13]:
# Shape of the logits for the first element along the leading dimension.
print(logits[0].shape)

torch.Size([80, 355])


In [26]:
# Flatten logits to 2-D (rows x last dimension) and targets to 1-D before the
# loss. Target entries equal to -100 are skipped, because loss_fn was created
# with ignore_index=-100.
loss = loss_fn(logits.view(-1, logits.size(-1)), harmony_target.view(-1))
print(loss)

tensor(2.6083, grad_fn=<NllLossBackward0>)


In [27]:
# One optimisation step: clear the gradients, backpropagate the loss, then
# update the parameters.
optimizer.zero_grad()
loss.backward()
optimizer.step()