This notebook is an at-a-glance, pseudocode-style walkthrough of training simple deep models.

## Imports

In [None]:
import datetime
import os
import time
import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader, random_split
# Restrict this process to GPU 0, then run on CUDA when it is available and on CPU otherwise.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('device = ', device)

from data_processing.PianoRollsDataset import PianoRollsDataset
from models.utils import get_class_weight_for_perf, draw_3_fig, draw_test_on, print_confusion_matrix, get_class_weight, get_num_params
from models.deep_models import CnnModel, BilstmModel, CrnnModel, Transformer
from models.trainer import MySingleClassTrainer
from settings.annotations import LABEL_DCT
from settings.evaluation import PERFORMANCE_DF_COLS
from settings.s3_info import file_path as s3_file_path
from settings.s3_info import int_to_string as piece_name_in_str_s3
from settings.s3_info import meta_csv_path as meta_csv_file_s3
from settings.orchestration_info import file_path as orchestration_file_path
from settings.orchestration_info import int_to_string as piece_name_in_str
from settings.orchestration_info import meta_csv_path as meta_csv_file

# Dataset identifier (the `ds=` value passed to PianoRollsDataset) -> the `file_path`
# setting imported for that dataset.
converted_path = dict(
    s3=s3_file_path,
    orchestration=orchestration_file_path,
)

## Config

In [None]:
experiment_name = '2025-06-03_40-COMB-1-5'
save_at = os.path.join('.', 'results', experiment_name)
print(f"{save_at = }")
use_model = 'cnn'  # ["cnn", "lstm", "crnn", "transformer"]
BLEND_MODE = 'COMB'  # ["SINGLE", "SUM", "COMB"]
k = 5                # if BLEND_MODE is 'COMB': the target track plus k-1 other tracks
dropout = 0
num_layers = 1
exp_dict = {  # Determines how much information is fed to the model
    0: {
        'add_inst': 'target_inst',  # instrument names: [None, 'target_inst', 'all']
        'add_barlines': True,       # bar line positions
        'add_replaymtx': True,      # replay matrix denoting the onset time of each note event (like a piano roll)
    },
}
exp_id = 0
add_inst = exp_dict[exp_id]['add_inst']
add_barlines = exp_dict[exp_id]['add_barlines']
add_replaymtx = exp_dict[exp_id]['add_replaymtx']
# Only report options that actually exist in this configuration.
print(f"Running experiment {exp_id} with parameters: {add_inst=}, {add_barlines=}, {add_replaymtx=}")

is_debug = True
k_fold = False
pretrained = None  # optional path to a saved state_dict to initialise the model from
PATIENT = 10
threshold = 0.5
BATCH_SIZE = 512
epoch_beg = 0
# NOTE(review): the debug run trains more epochs (2) than the non-debug run (1);
# this looks inverted or like a placeholder -- confirm the intended non-debug value.
epoch_end = 2 if is_debug else 1
CONTEXT_MEASURES = 1  # i.e., "m" in the paper
OTHER_INST = True  # if OTHER_INST is True:
                   #   'SUM': sum all tracks => 2 channels,
                   #   'COMB': choose k-1 other tracks => k channels


def count_input_channels(other_inst, blend_mode, k, add_inst, add_barlines, add_replaymtx):
    """Return the number of input channels implied by the data configuration.

    Args:
        other_inst: whether tracks other than the target one are fed to the model.
        blend_mode: 'SINGLE', 'SUM' or 'COMB'. Selects the number of piano-roll
            channels when ``other_inst`` is true, and the size of the instrument
            block when ``add_inst == 'all'``.
        k: number of stacked tracks when ``blend_mode == 'COMB'``.
        add_inst: None, 'target_inst' (one extra channel) or 'all'.
        add_barlines: add one bar-line channel.
        add_replaymtx: add one replay-matrix channel per piano-roll channel.

    Returns:
        The total channel count.

    Raises:
        ValueError: if ``other_inst`` is true and ``blend_mode`` is not recognised
            (previously this surfaced as a NameError on ``input_ch``).
    """
    if other_inst:
        # NOTE(review): 'SINGLE' is taken to mean one piano-roll channel, consistent with
        # the add_inst == 'all' branch below; confirm against PianoRollsDataset.
        rolls_per_mode = {'SINGLE': 1, 'SUM': 2, 'COMB': k}
        if blend_mode not in rolls_per_mode:
            raise ValueError(
                f"Unknown blend mode {blend_mode!r}; expected one of {sorted(rolls_per_mode)}"
            )
        n_rolls = rolls_per_mode[blend_mode]
    else:
        n_rolls = 1

    # The replay matrix mirrors every piano-roll channel.
    channels = 2 * n_rolls if add_replaymtx else n_rolls

    if add_inst == 'target_inst':
        channels += 1
    elif add_inst == 'all':
        channels += {'COMB': k, 'SUM': 2}.get(blend_mode, 1)

    if add_barlines:
        channels += 1
    return channels


# The number of input channels depends on how much information the model is fed.
input_ch = count_input_channels(OTHER_INST, BLEND_MODE, k, add_inst, add_barlines, add_replaymtx)


## Dataset
Split the Orchestration dataset into a training set, a validation set, and a test set. 

The test set consists of the first movements of Mozart's K. 504, Haydn's Hob. 99, and Beethoven's Op. 21.

Samples from the remaining pieces are shuffled and split 8:2 into the training set and the validation set.

In [None]:
# Movement indices held out as the test set (passed as `test_piece=` to PianoRollsDataset).
test_pieces_idx = list((0, 4, 9))

In [None]:
# Arguments shared by the training and test datasets; only `test=` differs between them.
dataset_kwargs = dict(
    test_piece=test_pieces_idx,
    context=CONTEXT_MEASURES,
    other_inst=OTHER_INST,
    blend=BLEND_MODE,
    k=k,
    add_inst=add_inst,
    add_barlines=add_barlines,
    add_replaymtx=add_replaymtx,
    converted_path=converted_path,
    ds='orchestration',
)
train_set = PianoRollsDataset(meta_csv_file, test=False, **dataset_kwargs)
test_set = PianoRollsDataset(meta_csv_file, test=True, **dataset_kwargs)

# 80/20 train/valid split of the non-test samples. A fixed seed makes the split
# reproducible, so resuming from a saved checkpoint later in the notebook does not
# move former validation samples into the training subset.
SPLIT_SEED = 0
train_size = int(0.8 * len(train_set))
valid_size = len(train_set) - train_size
train_subset, valid_subset = random_split(
    train_set,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation loaders are not shuffled (assumes the trainer's metrics are order-independent).
valid_loader = DataLoader(valid_subset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

## Models
Create a model and calculate its size.

In [None]:
if use_model == 'cnn':
    model = CnnModel(
        input_channels=input_ch,
        context_measures=CONTEXT_MEASURES,
    )
elif use_model == 'lstm':
    model = BilstmModel(
        input_channels=input_ch,
        context_measures=CONTEXT_MEASURES,
        dropout=dropout,
        num_layers=num_layers,
    )
elif use_model == 'crnn':
    model = CrnnModel(
        input_channels=input_ch,
        context_measures=CONTEXT_MEASURES,
    )
elif use_model == 'transformer':
    model = Transformer(
        input_channels=input_ch,
    )
else:
    # Fail loudly instead of a NameError on `model` a few lines later.
    raise ValueError(
        f"Unknown use_model={use_model!r}; expected 'cnn', 'lstm', 'crnn' or 'transformer'"
    )

# Move the model first so the optimizer below is built on the device-resident parameters.
model.to(device)
if pretrained is not None:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(pretrained, map_location=device))

# Print the architecture (printing model.parameters() only shows a generator object).
print(model)

# Turn per-label weights into per-output weights: each of the 8 label entries in
# LABEL_DCT adds its weight to the outputs its 'label' vector switches on.
# NOTE(review): presumably the 3 outputs are the melody / rhythm / harmony flags -- confirm.
NUM_LABEL_IDS, NUM_OUTPUTS = 8, 3
class_count_train, class_weight_train = get_class_weight(train_set)
per_output_weight = np.zeros(NUM_OUTPUTS)
for label_id in range(NUM_LABEL_IDS):
    per_output_weight += class_weight_train[str(label_id)] * np.array(LABEL_DCT[label_id]['label'])
class_weight_train = (per_output_weight / per_output_weight.sum()).tolist()

# Keep the weight in float32 like the model outputs; float16 only loses precision
# and mixes dtypes inside the loss computation.
loss_func = torch.nn.BCELoss(
    weight=torch.tensor(class_weight_train, dtype=torch.float32, device=device)
)
optimizer = torch.optim.Adam(model.parameters())

In [None]:
# Per-layer parameter table plus totals (trainable / non-trainable, as returned by get_num_params).
dflst, total_params, total_params_nograd = get_num_params(model, [], 0, 0)
df = pd.DataFrame(dflst)
print(f"{total_params=}, {total_params_nograd=}")
# No earlier cell visibly creates the results directory, so make sure it exists
# before writing into it.
os.makedirs(save_at, exist_ok=True)
df.to_csv(os.path.join(save_at, 'model_params.csv'))

## Trainer

In [None]:
# Class statistics the trainer needs when evaluating performance after each epoch
# (separate from the loss weights computed for BCELoss).
class_count, class_weight = get_class_weight_for_perf(train_set)

# Early stopping keeps the maximum of the monitored value, starting from -inf;
# 'patient' is presumably the number of non-improving epochs tolerated.
early_stopping_config = {
    'patient': PATIENT,
    'criteria': float("-inf"),
    'beat_epoch': 0,
    'rule': "max",
}

trainer = MySingleClassTrainer(
    model=model,
    optimizer=optimizer,
    scheduler=None,
    epoch_beg=epoch_beg,
    epoch_end=epoch_end,
    loss_func=loss_func,
    save_at=save_at,
    is_debug=is_debug,
    early_stopping=early_stopping_config,
    device=device,
    threshold=threshold,
    reset_patient=True,
    freeze=False,
    do_save_current_stage=True,
    compute_loss_with_msk=True,
    return_performance=True,
    piece_name='',
    class_weight=class_weight,
    class_count=class_count,
)


## Train on Orchestration dataset

In [None]:
# Create an empty table to store performance for each epoch
performance = pd.DataFrame(
    columns=['epoch', 'train_acc_l', 'test_acc_l', 'train_acc_d', 'test_acc_d',
    'train_precision_mel', 'train_recall_mel', 'test_precision_mel', 'test_recall_mel',
    'train_precision_rhythm', 'train_recall_rhythm', 'test_precision_rhythm', 'test_recall_rhythm',
    'train_precision_harm', 'train_recall_harm', 'test_precision_harm', 'test_recall_harm']
)


def _checkpoint_epoch(fname):
    """Return N parsed from a checkpoint file name 'epoch{N}.pt' or 'epoch{N}-best.pt'."""
    return int(fname.split('.pt')[0].split('-best')[0].split('epoch')[-1])


# Resume from a previous run if there is one: prefer the newest "-best" checkpoint,
# otherwise the newest checkpoint of any kind. On a fresh run the model directory
# may not exist yet, in which case there is nothing to load.
model_dir = os.path.join(save_at, 'model')
ckpt_files = (
    [f for f in os.listdir(model_dir) if f.endswith('.pt')]
    if os.path.isdir(model_dir) else []
)
best_files = [f for f in ckpt_files if f.endswith('best.pt')]
candidates = best_files or ckpt_files
if candidates:
    ckpt_name = max(candidates, key=_checkpoint_epoch)
    best_epoch = _checkpoint_epoch(ckpt_name)
    # Load the file that was actually found: a non-best checkpoint has no "-best" suffix.
    model.load_state_dict(torch.load(os.path.join(model_dir, ckpt_name), map_location=device))
    epoch_beg = max(epoch_beg, best_epoch)
else:
    best_epoch = None

# Record training time; t0 marks the start of the timed run.
time_spend = []
t0 = time.time()

In [None]:
# Wall-clock reference for the total written to time_spend.txt, taken when this cell starts
# so the loop does not depend on a timer defined in another cell.
run_t0 = time.time()
for epoch in range(epoch_beg, epoch_end):
    print('*'*10, f'Epoch {epoch:4d}', '*'*10)
    trainer.epoch = epoch

    t_train = time.time()
    current_train_loss, current_train_acc, _train_performance, train_info_dict = trainer.train(train_loader, epoch)

    t_valid = time.time()
    current_valid_loss, current_valid_acc, _valid_performance, valid_info_dict = trainer.valid(valid_loader, epoch)

    t_test = time.time()
    current_test_loss, current_test_acc, _test_performance, test_info_dict = trainer.test(test_loader, epoch)
    t_done = time.time()

    # Record time spent in each phase of this epoch.
    time_spend.append(f'Epoch {epoch} spends {datetime.timedelta(seconds=t_done-t_train)}: training {datetime.timedelta(seconds=t_valid-t_train)}, validing {datetime.timedelta(seconds=t_test-t_valid)}, testing {datetime.timedelta(seconds=t_done-t_test)}\n')
    # time_spend holds only the per-epoch lines; the running total is appended at write
    # time so it appears exactly once in the file instead of once per epoch.
    total = datetime.timedelta(seconds=time.time() - run_t0)
    with open(os.path.join(save_at, 'time_spend.txt'), 'w') as f:
        f.write('\n'.join(time_spend + [f'\n\nRunning all code: {total}']))

    # Let the trainer track the best validation accuracy (and save the model accordingly).
    trainer.check_early_stop(current_valid_acc)


In [None]:
# Restart the timer and the timing log, and start a fresh, empty per-epoch metric table.
t0 = time.time()
time_spend = []

# Columns: epoch, accuracies, then precision/recall for each task, train before test.
performance = pd.DataFrame(
    columns=['epoch', 'train_acc_l', 'test_acc_l', 'train_acc_d', 'test_acc_d']
    + [
        f'{stage}_{metric}_{task}'
        for task in ('mel', 'rhythm', 'harm')
        for stage in ('train', 'test')
        for metric in ('precision', 'recall')
    ]
)


## Evaluation

In [None]:
# After training, load the best model and test on each movement.


def _checkpoint_epoch(fname):
    """Return N parsed from a checkpoint file name 'epoch{N}.pt' or 'epoch{N}-best.pt'."""
    return int(fname.split('.pt')[0].split('-best')[0].split('epoch')[-1])


# Prefer the newest "-best" checkpoint; otherwise fall back to the newest checkpoint of
# any kind (the old fallback asked for "epoch{epoch_end-1}-best.pt", which cannot exist
# when no "-best" file was found).
model_dir = os.path.join(save_at, 'model')
ckpt_files = (
    [f for f in os.listdir(model_dir) if f.endswith('.pt')]
    if os.path.isdir(model_dir) else []
)
best_files = [f for f in ckpt_files if f.endswith('best.pt')]
candidates = best_files or ckpt_files
if candidates:
    ckpt_name = max(candidates, key=_checkpoint_epoch)
    best_epoch = _checkpoint_epoch(ckpt_name)
    model.load_state_dict(torch.load(os.path.join(model_dir, ckpt_name), map_location=device))
else:
    best_epoch = epoch_end - 1
    print(f"No checkpoint found under {model_dir}; evaluating the in-memory model.")


### Evaluating on orchestration dataset (test set)

In [None]:
os.makedirs(f'{save_at}/test_on/confusion_matrix', exist_ok=True)
tester = MySingleClassTrainer(
    model=model,
    optimizer=optimizer,
    scheduler=None,
    epoch_beg=epoch_beg,
    epoch_end=epoch_end,
    loss_func=loss_func,
    save_at=save_at,
    is_debug=is_debug,
    early_stopping=trainer.early_stopping,
    device=device, threshold=threshold,
    freeze=True,
    do_save_current_stage=False,
    return_performance=True
)
result = []
# Evaluate every movement separately.
# NOTE(review): 18 is assumed to be the number of movements in the orchestration meta CSV -- confirm.
for apdx in range(18):
    test_on_set = PianoRollsDataset(
        meta_csv_file,
        test=True,
        test_piece=[apdx],
        context=CONTEXT_MEASURES,
        other_inst=OTHER_INST,
        blend=BLEND_MODE,
        k=k,
        add_inst=add_inst,
        add_barlines=add_barlines,
        add_replaymtx=add_replaymtx,
        converted_path=converted_path,
        ds='orchestration',
    )
    test_on_loader = DataLoader(test_on_set, batch_size=BATCH_SIZE, shuffle=False)
    tester.piece_name = apdx
    class_count, class_weight = get_class_weight_for_perf(test_on_set)
    tester.class_count = class_count
    tester.class_weight = class_weight
    tester.model.eval()
    _, _, piece_perf, _ = tester.test(test_on_loader, epoch=0)
    piece_perf['piece_name'] = apdx
    # Pieces in test_pieces_idx were held out of training, so they belong to the test stage.
    piece_perf['stage'] = 'test' if apdx in test_pieces_idx else 'train'
    result.append(piece_perf)
    # Row-normalise the texture confusion matrix (each row sums to 1).
    cm_texture = piece_perf['cm_texture']
    piece_perf['cm_texture_new'] = cm_texture / cm_texture.sum(axis=1, keepdims=True)
    print_confusion_matrix(piece_perf['cm_texture_new'],
                           f"{apdx}_texture",
                           f'{save_at}/test_on', show=False)
    np.save(f"{save_at}/test_on/orch-{apdx}.npy", piece_perf['cm_texture_new'])
tmp_df = pd.DataFrame(result)[PERFORMANCE_DF_COLS]
tmp_df.to_csv(f'{save_at}/test_on/orch_performance.csv')


### Test on S3 dataset

In [None]:
os.makedirs(f'{save_at}/test_on/confusion_matrix', exist_ok=True)
tester = MySingleClassTrainer(
    model=model,
    optimizer=optimizer,
    scheduler=None,
    epoch_beg=epoch_beg,
    epoch_end=epoch_end,
    loss_func=loss_func,
    save_at=save_at,
    is_debug=is_debug,
    early_stopping=trainer.early_stopping,
    device=device, threshold=threshold,
    freeze=True,
    do_save_current_stage=False,
    return_performance=True
)

result = []
# Evaluate all S3 pieces together as a single test set.
# NOTE(review): pieces 0..15 are assumed to cover the whole S3 meta CSV -- confirm.
test_on_set = PianoRollsDataset(
    meta_csv_file_s3,
    test=True,
    test_piece=list(range(16)),
    context=CONTEXT_MEASURES,
    other_inst=OTHER_INST,
    blend=BLEND_MODE,
    k=k,
    add_inst=add_inst,
    add_barlines=add_barlines,
    add_replaymtx=add_replaymtx,
    converted_path=converted_path,
    ds='s3',
)
test_on_loader = DataLoader(test_on_set, batch_size=BATCH_SIZE, shuffle=False)
# The tester needs the performance-oriented class statistics, as in the other evaluation
# cells; get_class_weight returns the per-label loss weights instead.
class_count, class_weight = get_class_weight_for_perf(test_on_set)
tester.class_count = class_count
tester.class_weight = class_weight
tester.model.eval()
_, _, performance, _ = tester.test(test_on_loader, epoch=0)
performance['piece_name'] = 's3_all'
performance['stage'] = 'test'
result.append(performance)
# Row-normalise the texture confusion matrix (each row sums to 1).
cm_texture = performance['cm_texture']
performance['cm_texture_new'] = cm_texture / cm_texture.sum(axis=1, keepdims=True)
print_confusion_matrix(performance['cm_texture_new'],
                       "s3_all_texture",
                       f'{save_at}/test_on', show=False)
np.save(f"{save_at}/test_on/s3_all.npy", performance['cm_texture_new'])
tmp_df = pd.DataFrame(result)[PERFORMANCE_DF_COLS]
tmp_df.to_csv(f'{save_at}/test_on/s3_performance_all.csv')