### This notebook illustrates how to use customized control function to generate stylized music
Our approach enables controllable stylization in music generation. The sampling control is able to ensure that all generated notes strictly adhere to the target musical style's scale. This allows the model to generate music in specific styles — even those that were not present in the training data.

Below, we demonstrate several examples of style-controlled music generation for:

- Dorian Mode: (with scale being A-B-C-D-E-F#-G)
- Chinese Style: (with scale being C-D-E-G-A)

In [1]:
import sys
import os
import torch
from model import init_ldm_model
from model.model_sdf import Diffpro_SDF
from model.sampler_sdf import SDFSampler

import numpy as np
import pickle
from train.train_params import params_combined_cond
from generation_utils.fine_grained_control import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### set the global parameters here

In [2]:
STYLE = "chinese" # "dorian" or "chinese"
NUM_SAMPLES = 10 # number of samples to generate
RHYTHM_CONTROL = "soft_rhythm" # This parameter can be "hard_rhythm" or "soft_rhythm". "hard_rhythm" means that we strictly follow the rhythm control in sampling process; "soft_rhythm" means that we only control the rhythm in training but not in sampling.
SEPARATE_MELODY_ACCOMPANIMENT = False # True if want to use a model that generates accompaniment conditioning on chord and melody; False if want to train a model that generates melody and accompaniment conditioning on chord
MODEL_PATH = 'results/model-combine_melody_accompaniment-/example/chkpts/weights_example.pt' # path of the model
CONDITION_DATA_PATH = 'data/train_test_slices/test_slices_combine_melody_accompaniment.pkl' # path of the condition data
SAVE_CHORD_IN_MIDI = False # whether to save the chords in midi file


### customized control functions

In [3]:

def edit_rhythm(piano_roll_full, num_notes_onset, mask_full, sustain=None):
    '''
    piano_roll_full: a tensor with shape (batch_size, 2, length, h) # length=64 is length of roll, h is number of possible pitch
    num_notes_onset: a tensor with shape (batch_size, length)
    mask_full: a tensor with shape the same as piano_roll
    '''
    print("num_notes_onset_edit_rhythm", num_notes_onset)
    # print("sustain_edit_rhythm", sustain)
    ########## for those greater than the threshold, if num of notes exceed num_notes[i], 
    ########## will keep the first ones and set others to threshold
    print("editing rhythm")
    # we only edit onset
    onset_roll = piano_roll_full[:,0,:,:]
    sustain_roll = piano_roll_full[:,1,:,:]
    mask = mask_full[:,0,:,:]
    shape = onset_roll.shape

    onset_roll = onset_roll.reshape(-1,shape[-1])
    mask = mask.reshape(-1,shape[-1])
    num_notes = num_notes_onset.reshape(-1)
    print(num_notes)

    reduce_note_threshold = 0.499
    increase_note_threshold = 0.501

    # Initialize a tensor to store the modified values
    final_onset_roll = onset_roll.clone()
    threshold_mask = onset_roll > reduce_note_threshold
    # Set all values <= reduce_note_threshold to -inf to exclude them from top-k selection
    values_above_threshold = torch.where(threshold_mask & (mask == 1), onset_roll, torch.tensor(-float('inf')).to(onset_roll.device))

    # Get the top num_notes.max() values for each row
    num_notes_max = int(num_notes.max().item())  # Maximum number of notes needed in any row
    print("num_notes_max", num_notes_max)
    print("values_above_threshold", values_above_threshold.shape)
    topk_values, topk_indices = torch.topk(values_above_threshold, num_notes_max, dim=1)

    # Create a mask for the top num_notes[i] values for each row
    col_indices = torch.arange(num_notes_max, device=onset_roll.device).expand(len(onset_roll), num_notes_max)
    topk_mask = (col_indices < num_notes.unsqueeze(1)) & (topk_values > -float("inf"))

    # Set all values greater than reduce_note_threshold to reduce_note_threshold initially
    final_onset_roll[threshold_mask & (mask == 1)] = reduce_note_threshold

    # Create a flattened index to scatter the top values back into final_onset_roll
    flat_row_indices = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1).expand_as(topk_indices)
    flat_row_indices = flat_row_indices[topk_mask]

    # Gather the valid topk_indices and corresponding values
    valid_topk_indices = topk_indices[topk_mask]
    valid_topk_values = topk_values[topk_mask]

    # Use scatter to place the top num_notes[i] values back to their original positions
    final_onset_roll = final_onset_roll.index_put_((flat_row_indices, valid_topk_indices), valid_topk_values)


    # Count how many values >= increase_note_threshold for each row
    threshold_mask_2 = (final_onset_roll >= increase_note_threshold)&(mask==1)
    greater_than_threshold2_count = threshold_mask_2.sum(dim=1)

    # For those rows, find the remaining number of values needed to be set to increase_note_threshold
    remaining_needed = num_notes - greater_than_threshold2_count
    remaining_needed_max = int(remaining_needed.max().item())

    # Find the values in each row that are < increase_note_threshold but are the highest (so we can set them to increase_note_threshold)
    values_below_threshold2 = torch.where((final_onset_roll < increase_note_threshold)&(mask==1), final_onset_roll, torch.tensor(-float('inf')).to(onset_roll.device))
    topk_below_threshold2_values, topk_below_threshold2_indices = torch.topk(values_below_threshold2, remaining_needed_max, dim=1)

    # Mask to only adjust the needed number of values in each row
    col_indices_below_threshold2 = torch.arange(remaining_needed_max, device=onset_roll.device).expand(len(onset_roll), remaining_needed_max)
    adjust_mask = (col_indices_below_threshold2 < remaining_needed.unsqueeze(1)) & (topk_below_threshold2_values > -float("inf"))

    # Flatten row indices for the new top-k below increase_note_threshold
    flat_row_indices_below_threshold2 = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1).expand_as(topk_below_threshold2_indices)
    flat_row_indices_below_threshold2 = flat_row_indices_below_threshold2[adjust_mask]

    # Gather the valid indices and set them to increase_note_threshold
    valid_below_threshold2_indices = topk_below_threshold2_indices[adjust_mask]

    # Update the final_onset_roll to make sure we now have exactly num_notes[i] values >= increase_note_threshold
    final_onset_roll = final_onset_roll.index_put_((flat_row_indices_below_threshold2, valid_below_threshold2_indices), torch.tensor(increase_note_threshold, device=onset_roll.device))
    final_onset_roll = final_onset_roll.reshape(shape)
    piano_roll_full[:,0,:,:] = final_onset_roll

    if sustain is not None:
        print("edit sustain")
        final_sustain_roll = sustain_roll.clone()
        final_sustain_roll[:,0,:] = torch.clamp(final_sustain_roll[:,0,:], max=0.499)
        for t in range(1,shape[1]):
            current_sustain_roll = final_sustain_roll[:,t,:].clone()
            current_num_sustain = sustain[:,t]
            allowed_sustain_positions = torch.max(final_onset_roll[:,t-1,:], final_sustain_roll[:,t-1,:])>0.5

            current_sustain_roll = torch.where((~allowed_sustain_positions)&(current_sustain_roll>0.5), 0.499, current_sustain_roll)

            threshold_mask = current_sustain_roll > 0.5
            # Set all values <= reduce_note_threshold to -inf to exclude them from top-k selection
            values_above_threshold = torch.where(threshold_mask & (allowed_sustain_positions == 1), current_sustain_roll, torch.tensor(-float('inf')).to(current_sustain_roll.device))

            # Get the top num_notes.max() values for each row
            num_notes_max = int(current_num_sustain.max().item())  # Maximum number of notes needed in any row
            topk_values, topk_indices = torch.topk(values_above_threshold, num_notes_max, dim=1)

            # Create a mask for the top num_notes[i] values for each row
            col_indices = torch.arange(num_notes_max, device=current_sustain_roll.device).expand(len(current_sustain_roll), num_notes_max)
            topk_mask = (col_indices < current_num_sustain.unsqueeze(1)) & (topk_values > -float("inf"))

            # Set all values greater than reduce_note_threshold to reduce_note_threshold initially
            current_sustain_roll[threshold_mask & (allowed_sustain_positions == 1)] = reduce_note_threshold

            # Create a flattened index to scatter the top values back into final_onset_roll
            flat_row_indices = torch.arange(current_sustain_roll.size(0), device=current_sustain_roll.device).unsqueeze(1).expand_as(topk_indices)
            flat_row_indices = flat_row_indices[topk_mask]

            # Gather the valid topk_indices and corresponding values
            valid_topk_indices = topk_indices[topk_mask]
            valid_topk_values = topk_values[topk_mask]

            # Use scatter to place the top num_notes[i] values back to their original positions
            current_sustain_roll = current_sustain_roll.index_put_((flat_row_indices, valid_topk_indices), valid_topk_values)

            # Count how many values >= increase_note_threshold for each row
            threshold_mask_2 = (current_sustain_roll >= increase_note_threshold)&(allowed_sustain_positions==1)
            greater_than_threshold2_count = threshold_mask_2.sum(dim=1)

            # For those rows, find the remaining number of values needed to be set to increase_note_threshold
            remaining_needed = current_num_sustain - greater_than_threshold2_count
            remaining_needed_max = int(remaining_needed.max().item())

            # Find the values in each row that are < increase_note_threshold but are the highest (so we can set them to increase_note_threshold)
            values_below_threshold2 = torch.where((current_sustain_roll < increase_note_threshold)&(allowed_sustain_positions==1), current_sustain_roll, torch.tensor(-float('inf')).to(current_sustain_roll.device))
            topk_below_threshold2_values, topk_below_threshold2_indices = torch.topk(values_below_threshold2, remaining_needed_max, dim=1)

            # Mask to only adjust the needed number of values in each row
            col_indices_below_threshold2 = torch.arange(remaining_needed_max, device=current_sustain_roll.device).expand(len(current_sustain_roll), remaining_needed_max)
            adjust_mask = (col_indices_below_threshold2 < remaining_needed.unsqueeze(1)) & (topk_below_threshold2_values > -float("inf"))

            # Flatten row indices for the new top-k below increase_note_threshold
            flat_row_indices_below_threshold2 = torch.arange(current_sustain_roll.size(0), device=current_sustain_roll.device).unsqueeze(1).expand_as(topk_below_threshold2_indices)
            flat_row_indices_below_threshold2 = flat_row_indices_below_threshold2[adjust_mask]

            # Gather the valid indices and set them to increase_note_threshold
            valid_below_threshold2_indices = topk_below_threshold2_indices[adjust_mask]

            # Update the final_onset_roll to make sure we now have exactly num_notes[i] values >= increase_note_threshold
            current_sustain_roll = current_sustain_roll.index_put_((flat_row_indices_below_threshold2, valid_below_threshold2_indices), torch.tensor(increase_note_threshold, device=current_sustain_roll.device))
            final_sustain_roll[:,t,:] = current_sustain_roll
        piano_roll_full[:,1,:,:] = final_sustain_roll
    return piano_roll_full


def X0EditFunc_customized(x0, background_condition, sampler_device=device):
    if STYLE == "chinese":
        print("Chinese style")
        chroma = torch.tensor([1,0,1,0,0,1,0,1,0,0,1,0], device=device) # chroma，从A开始
        chroma = chroma.repeat(64//12+1)[:64]

        seven_notes_chroma_ours = chroma.view(1, 1, 1, 64)

        # Use expand to replicate the tensor
        seven_notes_chroma_ours = seven_notes_chroma_ours.expand(x0.shape[0], x0.shape[1], x0.shape[2], 64)

    elif STYLE == "dorian":
        print("dorian style")
        
        chroma = torch.tensor([1,0,1,1,0,1,0,1,0,1,1,0], device=device) # chroma，从A开始
        chroma = chroma.repeat(64//12+1)[:64]

        seven_notes_chroma_ours = chroma.view(1, 1, 1, 64)

        # Use expand to replicate the tensor
        seven_notes_chroma_ours = seven_notes_chroma_ours.expand(x0.shape[0], x0.shape[1], x0.shape[2], 64)
    
    else:
        print("regular style")
        maj_chd = torch.tensor([[1.,0,0,0,1,0,0,1,0,0,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device)
        maj_chd = torch.tile(maj_chd, (1, 64 // maj_chd.size(1) + 1))
        min_chd = torch.tensor([[1.,0,0,0,1,0,0,0,0,1,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device)
        min_chd = torch.tile(min_chd, (1, 64 // min_chd.size(1) + 1))
    
        # all chords, with rotation
        maj_chd_rotations = torch.stack([torch.roll(maj_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:,:,:64]
        min_chd_rotations = torch.stack([torch.roll(min_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:,:,:64]
        
        # combine all chords
        # chd_scale_map is a tensor with shape (N, 2, 64), N is total number of chord types, 
        # 2 is (chord_chroma, corresponding_scale_chroma), 64 is number of possible notes
        chd_scale_map = torch.concat([maj_chd_rotations, min_chd_rotations], axis=0)

        # if using null rhythm condition, have to convert -2 to 1 and -1 to 0
        if background_condition[:,:2,:,:].min()<0:
            correct_chord_condition = -background_condition[:,:2,:,:]-1
        else:
            correct_chord_condition = background_condition[:,:2,:,:]
        merged_chd_roll = torch.max(correct_chord_condition[:,0,:,:], correct_chord_condition[:,1,:,:]) # chd roll of our bg_cond
        chd_chroma_ours = torch.clamp(merged_chd_roll, min=0.0, max=1.0) # chd chroma of our bg_cond
        shape = chd_chroma_ours.shape
        chd_chroma_ours = chd_chroma_ours.reshape(-1,64)
        matches = (chd_scale_map[:, 0, :].unsqueeze(0) - chd_chroma_ours.unsqueeze(1)>=0).all(dim=-1)
        seven_notes_chroma_ours = torch.einsum('ij,jk->ik', matches.float(), chd_scale_map[:, 1, :]).reshape(shape)
        seven_notes_chroma_ours = seven_notes_chroma_ours.unsqueeze(1).repeat((1,2,1,1))

        no_chd_match = torch.all(seven_notes_chroma_ours == 0, dim=-1)
        seven_notes_chroma_ours[no_chd_match] = 1.

    # print(seven_notes_chroma_ours[0,0,:,:].sum(axis=0)[3:63].reshape(-1, 12))
    # print(background_condition[0,:,0,:])
    
    # original rhythm generated by the model
    rhythm = (x0[:,0,:,:]>0.5).sum(axis=-1)
    sustain = (x0[:,1,:,:]>0.5).sum(axis=-1)
    # print("rhythms\n\n\n", rhythm.shape, rhythm)

    #### edit notes based on chroma
    x0 = torch.where((seven_notes_chroma_ours==0)&(x0>0), 0.0 , x0)

    #### edit rhythm
    # num_onset_notes is rhythm based on our condition
    if background_condition[:,0,:,:].min()>=0:
        num_onset_notes, _ = torch.max(background_condition[:,0,:,:], axis=-1)
    else: # this means using null rhythm
        num_onset_notes = rhythm
    # print(num_onset_notes.shape, num_onset_notes)

    
    zero_mask = (num_onset_notes == 0)
    # Create an array to hold the result
    B = torch.zeros_like(num_onset_notes)
    # Iterate over each row
    for i in range(num_onset_notes.shape[0]):
        # Use torch.cummax to forward-fill non-zero values
        filled_row, _ = torch.cummax(num_onset_notes[i], dim=0)
        # In the result array, place the filled values where the original was zero
        B[i] = torch.where(zero_mask[i], filled_row, torch.tensor(0, dtype=num_onset_notes.dtype))
    
    # 如果希望严格按照我们的rhythm控制来，选这一行
    if RHYTHM_CONTROL == "hard_rhythm":
        x0 = edit_rhythm(x0, num_onset_notes, seven_notes_chroma_ours, B)
    else:
        # 如果希望按掉多少音就拔起来多少音，相当于保留模型自己生成的rhythm选这一行
        x0 = edit_rhythm(x0, rhythm, seven_notes_chroma_ours, sustain)
    return x0

### Define the background condition.

In [4]:
CHORD_DICTIONARY = { # 这个dictionary是从c到B
    "C:major": np.array([1,0,0,0,1,0,0,1,0,0,0,0]),
    "C#:major": np.array([0,1,0,0,0,1,0,0,1,0,0,0]),
    "D:major": np.array([0,0,1,0,0,0,1,0,0,1,0,0]),
    "Eb:major": np.array([0,0,0,1,0,0,0,1,0,0,1,0]),
    "E:major": np.array([0,0,0,0,1,0,0,0,1,0,0,1]),
    "F:major": np.array([1,0,0,0,0,1,0,0,0,1,0,0]),
    "F#:major": np.array([0,1,0,0,0,0,1,0,0,0,1,0]),
    "G:major": np.array([0,0,1,0,0,0,0,1,0,0,0,1]),
    "Ab:major": np.array([1,0,0,1,0,0,0,0,1,0,0,0]),
    "A:major": np.array([0,1,0,0,1,0,0,0,0,1,0,0]),
    "Bb:major": np.array([0,0,1,0,0,1,0,0,0,0,1,0]),
    "B:major": np.array([0,0,0,1,0,0,1,0,0,0,0,1]),

    "c:minor": np.array([1,0,0,1,0,0,0,1,0,0,0,0]),
    "c#:minor": np.array([0,1,0,0,1,0,0,0,1,0,0,0]),
    "d:minor": np.array([0,0,1,0,0,1,0,0,0,1,0,0]),
    "eb:minor": np.array([0,0,0,1,0,0,1,0,0,0,1,0]),
    "e:minor": np.array([0,0,0,0,1,0,0,1,0,0,0,1]),
    "f:minor": np.array([1,0,0,0,0,1,0,0,1,0,0,0]),
    "f#:minor": np.array([0,1,0,0,0,0,1,0,0,1,0,0]),
    "g:minor": np.array([0,0,1,0,0,0,0,1,0,0,1,0]),
    "g#:minor": np.array([0,0,0,1,0,0,0,0,1,0,0,1]),
    "a:minor": np.array([1,0,0,0,1,0,0,0,0,1,0,0]),
    "bb:minor": np.array([0,1,0,0,0,1,0,0,0,0,1,0]),
    "b:minor": np.array([0,0,1,0,0,0,1,0,0,0,0,1]),

    "d:minor7": np.array([1,0,1,0,0,1,0,0,0,1,0,0]),
    "G:dominant7": np.array([0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1]),
    "C:dominant7": np.array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0]),
    "A:dominant7": np.array([0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0]),
    "D:dominant7": np.array([0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1]),
    "F:dominant7": np.array([1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0]),
    "C:major7": np.array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1]), 
    "F:major7": np.array([1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0]),
    "bE:major7": np.array([0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0]),
    "bB:major7": np.array([0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0]), 

    "a:minor7": np.array([1,0,0,0,1,0,0,1,0,1,0,0]),
    "c:minor7": np.array([1,0,0,1,0,0,0,1,0,0,1,0]),
    "b:minor7b": np.array([0,0,1,0,0,1,0,0,0,1,0,1]),
    "e:minor7": np.array([0,0,1,0,1,0,0,1,0,0,0,1]),
}

def adjust_rhythm_string(s):
    # Truncate if longer than 16 characters
    if len(s) > 16:
        return s[:16]
    # Pad with zeros if shorter than 16 characters
    else:
        return s.ljust(16, '0')
def rhythm_string_to_array(s):
    # Ensure the string is 16 characters long
    if "null rhythm" in s:
        print("has null")
        return np.array([1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1])*np.array([3,1,2,1,3,1,2,1,3,1,2,1,3,1,2,1])
    s = s[:16].ljust(16, '0')  # Truncate or pad with '0' to make it 16 characters
    # Convert to numpy array, treating non-'0' as '1'
    arr = np.array([int(char) for char in s], dtype=int)

    print(arr)
    return arr

In [None]:
from generation_utils.fine_grained_control import circular_extend

num_samples = NUM_SAMPLES
'''
Here we define the chords and rhythms for each style.
'''
if STYLE=="chinese":
    chords = ["C:major", "F:major", "G:major", "a:minor"] # here we list the chords for each measure
    test_chd_roll = np.concatenate([np.tile(CHORD_DICTIONARY[chords[0]], (16, 1)),
                                np.tile(CHORD_DICTIONARY[chords[1]], (16, 1)), 
                                np.tile(CHORD_DICTIONARY[chords[2]], (16, 1)), 
                                np.tile(CHORD_DICTIONARY[chords[3]], (16, 1)),
                                ]) # each chord is repeated 16 times because the granularity is 16 per beat
    rhythms = ["3011211120102011", "3011201020101000", "3011201120101010", "1011111020001000"] # numbers are numbers of onsets in each 16th note
elif STYLE=="dorian":
    chords = ["a:minor", "e:minor","a:minor", "a:minor", "D:major", "a:minor", "D:major"] # here we list the chords
    test_chd_roll = np.concatenate([np.tile(CHORD_DICTIONARY[chords[0]], (16, 1)), 
                                np.tile(CHORD_DICTIONARY[chords[1]], (8, 1)), 
                                np.tile(CHORD_DICTIONARY[chords[2]], (8, 1)),
                                np.tile(CHORD_DICTIONARY[chords[3]], (8, 1)),
                                np.tile(CHORD_DICTIONARY[chords[4]], (8, 1)),
                                np.tile(CHORD_DICTIONARY[chords[5]], (8, 1)),
                                np.tile(CHORD_DICTIONARY[chords[6]], (8, 1))])
    rhythms = ["3010101111111000", "1111101111111000", "1010101110101011", "1011101010101000"]
else:
    raise ValueError("Invalid style")

print("test_chd_roll", test_chd_roll.shape)
chd_roll = np.concatenate([test_chd_roll[np.newaxis,:,:], test_chd_roll[np.newaxis,:,:]], axis=0)

chd_roll = circular_extend(chd_roll)
chd_roll = -chd_roll-1

real_chd_roll = chd_roll

melody_roll = -np.ones_like(chd_roll)


rhythm_full = []
for i in range(len(rhythms)):
    rhythm = adjust_rhythm_string(rhythms[i])
    rhythm = rhythm_string_to_array(rhythm)
    rhythm_full.append(rhythm)
rhythm_full = np.concatenate(rhythm_full, axis=0)

onset_roll = test_chd_roll*rhythm_full[:, np.newaxis]
sustain_roll = np.zeros_like(onset_roll)
no_onset_pos = np.all(onset_roll == 0, axis=-1)
sustain_roll[no_onset_pos] = test_chd_roll[no_onset_pos]

real_chd_roll = np.concatenate([onset_roll[np.newaxis,:,:], sustain_roll[np.newaxis,:,:]], axis=0)
real_chd_roll = circular_extend(real_chd_roll)

for i in range(len(rhythms)):
    if rhythms[i] == "null rhythm":
        real_chd_roll[0,i*16:(i+1)*16,:] = -real_chd_roll[0,i*16:(i+1)*16,:].clip(max=1)-1
        real_chd_roll[1,i*16:(i+1)*16,:] = real_chd_roll[0,i*16:(i+1)*16,:]

background_condition = np.concatenate([real_chd_roll, chd_roll], axis=0)
background_condition = torch.Tensor(np.tile(background_condition, (num_samples,1,1,1))).to(device)

test_chd_roll (64, 12)
[3 0 1 1 2 1 1 1 2 0 1 0 2 0 1 1]
[3 0 1 1 2 0 1 0 2 0 1 0 1 0 0 0]
[3 0 1 1 2 0 1 1 2 0 1 0 1 0 1 0]
[1 0 1 1 1 1 1 0 2 0 0 0 1 0 0 0]


### start generation

In [6]:
# load the model
ldm_model = init_ldm_model(params_combined_cond, debug_mode=False)
model = Diffpro_SDF.load_trained(ldm_model, MODEL_PATH).to(device)
sampler = SDFSampler(model.ldm, 64, 64, is_autocast=False, device=device, debug_mode=False)

  trained_leaner = torch.load(chkpt_fpath, map_location=device)
  self.autocast = torch.cuda.amp.autocast(enabled=is_autocast)


In [7]:
output_x = sampler.generate(background_cond=background_condition, batch_size=background_condition.shape[0], 
                            same_noise_all_measure=False, X0EditFunc=X0EditFunc_customized, use_classifier_free_guidance=True,
                            use_melody=False)
output_x = torch.clamp(output_x, min=0, max=1)
output_x = output_x.cpu().numpy()

[10, 2, 64, 64]


p_sample
Chinese style
num_notes_onset_edit_rhythm tensor([[15, 10,  3,  6, 11,  4,  4,  5, 10,  3,  6,  5,  7,  3,  3, 12,  8,  4,
          6,  9, 11,  6,  7,  3,  6,  2,  8,  0,  9,  3,  5,  9, 12,  5,  9, 10,
          9,  2,  9,  6, 10,  8,  9,  4,  8,  3,  7,  4,  4,  1,  3, 11,  4,  6,
          7,  4, 12,  2,  2,  3,  2,  0, 12, 23],
        [15, 14,  5,  8, 12,  8,  9,  9, 12,  7,  7,  0, 14,  4, 10, 11, 16,  3,
          8, 12, 13,  2,  6,  3, 12,  2,  9,  5,  9,  5,  3,  4, 15,  2, 12, 12,
         12,  6,  4, 10, 13,  5,  4,  1,  8,  4, 10,  9,  3,  4,  7,  9,  9,  9,
          9,  6, 10,  3,  3,  5,  9,  4, 12, 22],
        [15,  9,  9,  6,  8,  5, 10,  5, 10,  4,  7,  2, 10,  3, 11, 10, 11,  4,
         10,  9,  8,  5,  9,  4,  8,  4,  7,  2,  6,  4,  2,  4, 13,  1,  8, 10,
         12,  5,  8,  6, 11,  4,  6,  4,  6,  4,  8,  4, 16,  3, 15,  6,  8,  9,
          8,  8,  6,  1,  3,  1,  5,  1, 13, 19],
        [15, 11, 10,  4,  8,  6, 11,  9, 10,  4,  3,  5, 11,  4,  7,  

In [8]:
from data.prepare_training_pianoroll.convert_to_midi import extend_piano_roll, piano_roll_to_midi, save_midi
import subprocess

full_piano_rolls = []
for i in range(output_x.shape[0]):
    full_roll = extend_piano_roll(output_x[i]) # accompaniment roll
    full_piano_rolls.append(full_roll[np.newaxis,:])
    chd_roll = -background_condition[i,2:4,:,:].cpu().numpy()-1
    chd_roll[0,:,:] = 0.
    chd_roll[0,0::4,:] = chd_roll[1,0::4,:]
    full_chd_roll = extend_piano_roll(chd_roll) # chord roll
    # full_lsh_roll = None
    # if background_cond.shape[1]>=6:
    #     if background_cond[:,4:6,:,:].min()>=0:
    #         full_lsh_roll = extend_piano_roll(background_cond[i,4:6,:,:].cpu().numpy())

    midi_file = piano_roll_to_midi(piano_roll=full_roll, chd_roll=None, melody_roll=None, bpm=80)
    os.makedirs(f'generated_samples/style_{STYLE}', exist_ok=True)
    filename = f"generated_samples/style_{STYLE}/sample_{i}.mid"
    save_midi(midi_file, filename)

    # convert midi to wav
    subprocess.Popen(['timidity',f"generated_samples/style_{STYLE}/sample_{i}.mid",'-Ow','-o',f"generated_samples/style_{STYLE}/sample_{i}.wav"]).communicate()

Playing generated_samples/style_chinese/sample_0.mid
MIDI file: generated_samples/style_chinese/sample_0.mid
Format: 1  Tracks: 2  Divisions: 220


Playing time: ~16 seconds
Notes cut: 0
Notes lost totally: 0
Playing generated_samples/style_chinese/sample_1.mid
MIDI file: generated_samples/style_chinese/sample_1.mid
Format: 1  Tracks: 2  Divisions: 220
Playing time: ~15 seconds
Notes cut: 0
Notes lost totally: 0
Playing generated_samples/style_chinese/sample_2.mid
MIDI file: generated_samples/style_chinese/sample_2.mid
Format: 1  Tracks: 2  Divisions: 220
Playing time: ~15 seconds
Notes cut: 0
Notes lost totally: 0
Playing generated_samples/style_chinese/sample_3.mid
MIDI file: generated_samples/style_chinese/sample_3.mid
Format: 1  Tracks: 2  Divisions: 220
Playing time: ~15 seconds
Notes cut: 0
Notes lost totally: 0
Playing generated_samples/style_chinese/sample_4.mid
MIDI file: generated_samples/style_chinese/sample_4.mid
Format: 1  Tracks: 2  Divisions: 220
Playing time: ~15 seconds
Notes cut: 0
Notes lost totally: 0
Playing generated_samples/style_chinese/sample_5.mid
MIDI file: generated_samples/style_chinese/sample_5.mid
Fo