In [1]:
from music21 import pitch
from pprint import pprint, pformat

import os
import random
import csv
import math
import time
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F

## Utils

In [3]:
class PianoPiece:
    def __init__(self, notes=[], fingers=[], intervals=[], accidentals=[], ids=[], duration=[], file_name=""):
        self.notes = notes
        self.fingers = fingers
        self.intervals = intervals
        self.accidentals = accidentals
        self.ids = ids
        self.durations = duration
        self.file_name = file_name


class PianoFingeringDataset(Dataset):
    def __init__(self, filenames, data_dir, aug=False):
        self.input_list, self.label_list, self.processed_data = prepare_inputs(filenames, data_dir, aug)

    def __len__(self):
        return len(self.input_list)

    def __getitem__(self, idx):
        x = torch.tensor(self.input_list[idx], dtype=torch.float32)
        y = torch.tensor(self.label_list[idx], dtype=torch.int64)  
        return x, y


In [2]:
def reorganize_fingers(filename, data_dir):
    """
    Reorganizes pitch data for chords in the right hand.

    For each chord, the notes (and its rows) are sorted in ascending order.
    Example: Given a chord G4, E4, and C4 is reordered as C4, E4, G4.
    """

    data = []
    pre_onset = 0
    pre_row = None

    filepath = data_dir + filename
    df = pd.read_csv(filepath, header=None)
    for index, row in df.iterrows():
        min_index = 9999
        data = []
        current_onset = float(row[1])
        current_finger = int(str(row.iloc[-1]).split('_')[0])  

        if current_finger > 0: # right hand
            if math.isclose(pre_onset, current_onset, rel_tol=1e-4):
                for i in range(index - 1, index + 4, 1): # finding another notes of chord 
                    if i < len(df):
                        if df.iloc[i, 1] == current_onset and str(df.iloc[i, -2]) == '0':
                            if min_index > df.iloc[i, 0]:
                                min_index = df.iloc[i, 0]
                            data.append(df.iloc[i].tolist())
                
                data_sort = sorted(data.copy(), key = lambda x: float(pitch.Pitch(x[3]).ps))
                for idx, value in enumerate(data_sort):
                    value[0] = min_index + idx # changing the idx of note
                
                for i in range(min_index, len(data_sort) + min_index, 1):
                    df.iloc[i] = data_sort[i - min_index]
               
        pre_onset = current_onset
        pre_row = row.tolist()
    
    return df


In [5]:
interval_to_midi = {
    # "Unison": 0,
    # "Minor Second": 1,
    # "Major Second": 2,
    # "Minor Third": 3,
    # "Major Third": 4,
    # "Perfect Fourth": 5,
    # "Tritone": 6,
    # "Perfect Fifth": 7,
    # "Minor Sixth": 8,
    # "Major Sixth": 9,
    # "Minor Seventh": 10,
    # "Major Seventh": 11,
    "Octave": 12
}


def pass_bounds(notes):
    surpass = False
    for n in notes:
        if not (n == 0 or (21 <= n < 108)):
            surpass = True
    return surpass


def interval_symmetry(piece, interval):
    """
    Generates symmetrical piece by applying interval shifts across multiple octaves.
    """
    pieces = []
    
    octaves = [-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    for octave in octaves:
        new_notes = [(n + (octave * interval)) if n != 0 else 0 for n in piece.notes]
        if not pass_bounds(new_notes):
            pieces.append(PianoPiece(new_notes, 
                                     piece.fingers, 
                                     piece.intervals, 
                                     piece.accidentals, 
                                     piece.ids, 
                                     piece.durations,
                                     piece.file_name))
    return pieces


In [7]:
def load_piano_piece(filename, data_dir, aug=False):
    df = reorganize_fingers(filename, data_dir)
    df.columns = ["ID", "Onset", "Offset", "PitchName", "Column4", "Column5", "Channel", "Finger"]
    
    df["Onset"] = df["Onset"].astype(float)
    df["Offset"] = df["Offset"].astype(float)
    df["Finger"] = df["Finger"].astype(str).str.split('_').str[0].astype(int)
    df["Note"] = df["PitchName"].apply(lambda x: pitch.Pitch(x).ps)
    df["Accidental"] = df["PitchName"].apply(lambda x: int(pitch.Pitch(x).accidental is None)) # whether note is black(0) or white(1) 
    
    notes, fingers, accidentals, ids, durations = [], [], [], [], []
    pre_onset = 0
    
    for _, row in df.iterrows():
        if row["Finger"] > 0:
            notes.append(row["Note"])
            fingers.append(row["Finger"])
            accidentals.append(row["Accidental"])
            ids.append(int(row["ID"]))
            durations.append(round(row["Offset"] - row["Onset"], 2))
        pre_onset = row["Onset"]
    
    intervals = np.diff(np.array(notes, dtype=int)).tolist()
    piece = PianoPiece(notes, fingers, intervals, accidentals, ids, durations, filename)
    pieces = []
    
    if aug:
        for interval in interval_to_midi.values():
            res = interval_symmetry(piece, interval)
            pieces.extend(res)
        return pieces
    else:
        return [piece]


In [9]:
def split_files(data_dir, train_ratio, val_ratio, test_ratio):
    all_files = sorted(os.listdir(data_dir))
    random.seed(42)
    random.shuffle(all_files)  
    total_files = len(all_files)

    train_end = int(total_files * train_ratio)
    val_end = train_end + int(total_files * val_ratio)

    train_files = all_files[:train_end]
    val_files = all_files[train_end:val_end]
    test_files = all_files[val_end:]
    
    return train_files, val_files, test_files

In [177]:
def prepare_inputs(filenames, data_dir, aug=False):

    """
    Prepare inputs for the neural network.
    
    The input vector consists of:
    1. Fingering of the current note
    2. The semitone distance (midi) to the next note
    3. Whether the current note is a black key
    4. Whether the next note is a black key
    5. Duration of the current note

    The labels are fingers
    """

    inputs = []
    labels = []
    processed_data = {}
    
    for filename in tqdm(sorted(filenames), desc="Processing Files"):
        vector_list = []
        pieces = load_piano_piece(filename, data_dir, aug)
        
        for i in range(len(pieces)):
            vector_list.append([
                [f, il, bw_s, bw_e, dur]
                for f, il, bw_s, bw_e, dur in zip(
                    pieces[i].fingers[:-1],
                    pieces[i].intervals,
                    pieces[i].accidentals[:-1],
                    pieces[i].accidentals[1:],
                    pieces[i].durations[:-1]
                )
            ])
        processed_data[filename] = len(pieces)
        for i in range(len(vector_list)):
            inputs.extend(
                [l for l in slide_window_future_gen(vector_list[i], BLOCK_LENGTH, FUTURE_LENGTH)]
            )
            labels.extend(
                [f for f in pieces[i].fingers[BLOCK_LENGTH - FUTURE_LENGTH : -FUTURE_LENGTH]]
            )

        
    return inputs, labels, processed_data


In [11]:
def create_dataloaders(data_dir, batch_size, train_ratio, val_ratio, test_ratio):
    train_files, val_files, test_files = split_files(data_dir, train_ratio, val_ratio, test_ratio)

    print(f"Train set {len(train_files)}")
    print(f"Vaidation set {len(val_files)}")
    print(val_files)
    print(f"Test set {len(test_files)}")
    print(test_files)
    print()
    
    train_dataset = PianoFingeringDataset(train_files, data_dir, aug=True)
    val_dataset = PianoFingeringDataset(val_files, data_dir)
    test_dataset = PianoFingeringDataset(test_files, data_dir)

    len_train = 0
    
    for i in train_dataset.processed_data.values():
        len_train += int(i)
    print(f"Train set after {len_train}")
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader, train_files, val_files, test_files

## Training

In [178]:
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, block_length, future_length):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True, num_layers=3)
        self.lambda_layer_idx = block_length - future_length - 1
        self.dropout = nn.Dropout(p=0.4)
        self.fc = nn.Linear(hidden_size * 2, output_size)  
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        lstm_out, _ = self.lstm(x)  
        selected_output = lstm_out[:, self.lambda_layer_idx, :]  
        logits = self.fc(selected_output)  
        probabilities = self.softmax(logits)
        return probabilities



In [8]:
def slide_window_future_gen(input_list, window_size, future_size):
    for start in range(len(input_list) - window_size + 1):
        full_list = input_list[start : start + window_size]
        for i in range(window_size-future_size, window_size):
            full_list[i][0] = 0
        yield full_list

In [179]:
def train_model(model, device, train_loader, val_loader, num_epochs, name=None):
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        start_time = time.time()  
        
        model.train()
        epoch_loss = 0.0
        correct = 0
        total = 0

        for idx, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            labels = labels - 1 

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        epoch_loss /= total
        accuracy = correct / total
        
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for idx, (inputs, labels) in enumerate(val_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                labels = labels - 1 
                
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)

        val_loss /= val_total
        val_accuracy = val_correct / val_total

        epoch_time = time.time() - start_time

        print(f"Epoch [{epoch + 1}/{num_epochs}], "
              f"Loss: {epoch_loss:.4f}, "
              f"Accuracy: {accuracy:.4f}, "
              f"Time: {epoch_time:.2f} sec ")
        
        print(f"Epoch [{epoch + 1}/{num_epochs}], "
              f"Val Loss: {val_loss:.4f}, "
              f"Validation Accuracy: {val_accuracy:.4f}")
        print()

    if name is not None:
        print(f"The model is saved as {name}")
        torch.save(model.state_dict(), name)
    return model


In [180]:
BLOCK_LENGTH = 11
FUTURE_LENGTH = 5
N_HIDDEN=128
FINGER_SIZE = 5
DATA_DIR = "/kaggle/input/pig-dataset/"
BATCH_SIZE = 256
NUM_EPOCHS = 20
TRAIN_RATIO = 0.7
VAL_RATIO = 0.25
TEST_RATIO = 0.05
INPUT_SIZE = 5

block_future = [(11, 5)]


def create_bi_direction_with_future_model(block_length, future_length):
    return LSTM(input_size=INPUT_SIZE,  # Assuming input features are scalar
                                       hidden_size=N_HIDDEN,
                                       output_size=FINGER_SIZE,
                                       block_length=block_length,
                                       future_length=future_length)



# block_future = [(7, 5), (8, 5), (9, 5), (10, 5), (11, 5),
#                 (7, 4), (9, 4),
#                 (11, 6), (11, 7), (11, 8),
#                 (12, 8), (12, 9), (12, 4),
#                 (15, 10), (15, 12),
#                 (13, 8), (9, 6), (14, 10)
#                ]
    
train_loader, val_loader, test_loader, train_files, val_files, test_files = create_dataloaders(DATA_DIR, BATCH_SIZE, TRAIN_RATIO, VAL_RATIO, TEST_RATIO)


Train set 216
Vaidation set 77
['072-1_fingering.csv', '014-1_fingering.csv', '004-5_fingering.csv', '066-1_fingering.csv', '065-1_fingering.csv', '076-1_fingering.csv', '018-3_fingering.csv', '049-2_fingering.csv', '013-6_fingering.csv', '128-2_fingering.csv', '124-1_fingering.csv', '043-1_fingering.csv', '015-7_fingering.csv', '112-1_fingering.csv', '024-5_fingering.csv', '106-1_fingering.csv', '014-6_fingering.csv', '059-1_fingering.csv', '028-3_fingering.csv', '011-5_fingering.csv', '109-1_fingering.csv', '035-1_fingering.csv', '005-5_fingering.csv', '044-1_fingering.csv', '129-1_fingering.csv', '097-1_fingering.csv', '052-1_fingering.csv', '015-6_fingering.csv', '047-2_fingering.csv', '013-3_fingering.csv', '019-4_fingering.csv', '020-1_fingering.csv', '011-3_fingering.csv', '019-7_fingering.csv', '074-2_fingering.csv', '042-1_fingering.csv', '115-2_fingering.csv', '145-1_fingering.csv', '020-5_fingering.csv', '012-7_fingering.csv', '138-1_fingering.csv', '121-1_fingering.csv', '0

Processing Files: 100%|██████████| 216/216 [00:34<00:00,  6.29it/s]
Processing Files: 100%|██████████| 77/77 [00:10<00:00,  7.00it/s]
Processing Files: 100%|██████████| 16/16 [00:02<00:00,  7.16it/s]

<class 'dict'>
Train set after 1072





In [181]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

for i in block_future:
    BLOCK_LENGTH = i[0]
    FUTURE_LENGTH = i[1]
    model = create_bi_direction_with_future_model(BLOCK_LENGTH, FUTURE_LENGTH)
    print(f"Model is created with BLOCK_LENGTH = {BLOCK_LENGTH} and FUTURE_LENGTH = {FUTURE_LENGTH}")
    train_model(model, device, train_loader, val_loader, NUM_EPOCHS, f"lstm_b{BLOCK_LENGTH}_f{FUTURE_LENGTH}.pt")
    print("__________________________________________________________")
    print()

cuda
Model is created with BLOCK_LENGTH = 11 and FUTURE_LENGTH = 5
Epoch [1/20], Loss: 1.2369, Accuracy: 0.6631, Time: 11.26 sec 
Epoch [1/20], Val Loss: 1.2176, Validation Accuracy: 0.6819

Epoch [2/20], Loss: 1.1606, Accuracy: 0.7434, Time: 10.99 sec 
Epoch [2/20], Val Loss: 1.1823, Validation Accuracy: 0.7190

Epoch [3/20], Loss: 1.1267, Accuracy: 0.7780, Time: 11.00 sec 
Epoch [3/20], Val Loss: 1.1802, Validation Accuracy: 0.7218

Epoch [4/20], Loss: 1.1033, Accuracy: 0.8025, Time: 11.06 sec 
Epoch [4/20], Val Loss: 1.1732, Validation Accuracy: 0.7305

Epoch [5/20], Loss: 1.0919, Accuracy: 0.8134, Time: 11.25 sec 
Epoch [5/20], Val Loss: 1.1788, Validation Accuracy: 0.7232

Epoch [6/20], Loss: 1.0799, Accuracy: 0.8257, Time: 10.97 sec 
Epoch [6/20], Val Loss: 1.1662, Validation Accuracy: 0.7373

Epoch [7/20], Loss: 1.0720, Accuracy: 0.8335, Time: 11.02 sec 
Epoch [7/20], Val Loss: 1.1722, Validation Accuracy: 0.7316

Epoch [8/20], Loss: 1.0661, Accuracy: 0.8393, Time: 11.21 sec 
Ep

## Evaluating

In [182]:
def group_files_by_prefix(test_files):
    file_dict = {}
    for test_file in test_files:
        pre_fix = test_file.split('-')[0]
        if pre_fix in file_dict:
            file_dict[pre_fix].append(test_file)
        else:
            temp_list = [test_file]
            file_dict[pre_fix] = temp_list
    return file_dict

In [183]:
def update_state_with_prediction(old_state, finger_pred, new_vec, future_size):
    pred = old_state[-future_size]
    pred[0] = finger_pred  # Update the predicted finger

    # Updating the state with the new vector as a tensor
    new_state = torch.tensor([0] + new_vec, dtype=torch.float32)
    old_state[-future_size] = pred
    return old_state[1:] + [new_state]  


In [184]:
def prepare_test_inputs(filename, data_dir):
    inputs = []
    labels = []

    pieces = load_piano_piece(filename, data_dir)
    vector = [
        [il, bw_s, bw_e, dur]
        for il, bw_s, bw_e, dur in zip(
            pieces[0].intervals,
            pieces[0].accidentals[:-1],
            pieces[0].accidentals[1:],
            pieces[0].durations[:-1]
        )
    ]
        
    inputs.append(vector)
    labels.append(pieces[0].fingers)
    
    return inputs, labels, pieces[0].ids

In [185]:
def predict_fingerings(input_list, label_list, model, device):
    model.to(device)
    model.eval()
    results = []

    with torch.no_grad():
        for test_vector, test_finger in zip(input_list, label_list):
            init_state_b = [
                torch.tensor([test_finger[i]] + test_vector[i], dtype=torch.float32)
                for i in range(BLOCK_LENGTH - FUTURE_LENGTH)
            ]
            init_state_a = [
                torch.tensor([0] + test_vector[i], dtype=torch.float32)
                for i in range(BLOCK_LENGTH - FUTURE_LENGTH, BLOCK_LENGTH)
            ]

            init_state = init_state_b + init_state_a
            num_intervals = len(test_vector)
            temp_finger_res = []

            for test_step in range(0, num_intervals - BLOCK_LENGTH + 1):
                np_init_state = (
                    torch.stack(init_state)
                    .view(-1, BLOCK_LENGTH, INPUT_SIZE)
                    .to(device)
                )
                pred_prob = model(np_init_state)
                finger_pred = torch.argmax(pred_prob, dim=1).item() + 1
                temp_finger_res.append(finger_pred)

                if test_step < num_intervals - BLOCK_LENGTH - 1:
                    next_vector = test_vector[test_step + BLOCK_LENGTH]
                    init_state = update_state_with_prediction(
                        init_state, finger_pred, next_vector, FUTURE_LENGTH
                    )

            temp_finger_res = (test_finger[: BLOCK_LENGTH - FUTURE_LENGTH] + temp_finger_res + test_finger[-FUTURE_LENGTH:])
            results.append(temp_finger_res)

    return results


In [186]:
def evaluate_fingering(test_files, model, device):
    file_dict = group_files_by_prefix(test_files)
    total_correct = 0
    total_predictions = 0

    for hmm_res_file in test_files:
        pre_fix = hmm_res_file.split('-')[0]
        
        if pre_fix in file_dict:
            test_input_list, test_label_list, test_id_list = prepare_test_inputs(hmm_res_file, DATA_DIR\)
            predicted_fingerings = predict_fingerings(test_input_list, test_label_list, model, device)
            
            flat_pred = [pred for pred in predicted_fingerings[0]]
            flat_label = [gt for gt in test_label_list[0]]

            correct = sum(p == gt for p, gt in zip(flat_pred, flat_label))
            total_correct += correct
            total_predictions += len(flat_label)
            
            file_accuracy = correct / len(flat_label) if len(flat_label) > 0 else 0
            print(f"File: {hmm_res_file} | Accuracy: {file_accuracy:.4f}")

    overall_accuracy = total_correct / total_predictions if total_predictions > 0 else 0
    print(f"Overall Categorical Accuracy: {overall_accuracy:.4f}")


In [187]:
model = create_bi_direction_with_future_model(BLOCK_LENGTH, FUTURE_LENGTH)
model.load_state_dict(torch.load("/kaggle/working/lstm_b11_f5.pt", weights_only=True))

# for dirname, _, filenames in os.walk('/kaggle/working/new_csv_train'):
#     for filename in filenames: 
#         test_files.append(filename)
        
evaluate_fingering(test_files,  model, device)

File: 115-1_fingering.csv | Accuracy: 0.7237
File: 024-6_fingering.csv | Accuracy: 0.7063
File: 023-3_fingering.csv | Accuracy: 0.7055
File: 012-3_fingering.csv | Accuracy: 0.6562
File: 004-8_fingering.csv | Accuracy: 0.7073
File: 005-1_fingering.csv | Accuracy: 0.6369
File: 076-2_fingering.csv | Accuracy: 0.7101
File: 011-6_fingering.csv | Accuracy: 0.7770
File: 128-1_fingering.csv | Accuracy: 0.8465
File: 013-1_fingering.csv | Accuracy: 0.6118
File: 016-3_fingering.csv | Accuracy: 0.6000
File: 023-6_fingering.csv | Accuracy: 0.7603
File: 026-1_fingering.csv | Accuracy: 0.8703
File: 029-1_fingering.csv | Accuracy: 0.7579
File: 004-1_fingering.csv | Accuracy: 0.5607
File: 013-7_fingering.csv | Accuracy: 0.6184
Overall Categorical Accuracy: 0.7186
