In [None]:
!pip install optuna torch pandas scikit-learn # run if modules are not installed



In [2]:
import torch
import torch.nn as nn
import pandas as pd
import random
from torch.utils.data import Dataset, DataLoader
from fractions import Fraction

# Load the dataset from the CSV file (ensure output.csv is uploaded to your Colab environment)
df = pd.read_csv('output_FullDataset.csv')

# Inspect the first few rows and columns to verify column names
print(df.head())
print("Columns:", df.columns)

# Build a vocabulary mapping from note tokens to unique integers (reserve 0 for padding)
def build_note_vocab(dataframe):
    notes_set = set()
    # Only use rows where 'Type' equals "Note"
    for note_str in dataframe[dataframe['Type'] == "Note"]['Note']:
        notes_set.update(note_str.split())
    note2idx = {note: idx + 1 for idx, note in enumerate(sorted(notes_set))}
    return note2idx

note2idx = build_note_vocab(df)
print("Note Vocabulary mapping:", note2idx)

# Build a vocabulary mapping for chord labels from the 'Active Chord' column
def build_chord_vocab(dataframe):
    chords = dataframe[dataframe['Type'] == "Note"]['Active Chord'].dropna().unique()
    chords_str = sorted([str(chord) for chord in chords])
    chord2idx = {chord: idx for idx, chord in enumerate(chords_str)}
    return chord2idx

chord2idx = build_chord_vocab(df)
print("Chord mapping:", chord2idx)

                    File           Part  Measure          Note  Octave  \
0  1974%20Blues.musicxml  MusicXML Part        1  F2.A2.C3.E-3     NaN   
1  1974%20Blues.musicxml  MusicXML Part        1             D     3.0   
2  1974%20Blues.musicxml  MusicXML Part        1             D     3.0   
3  1974%20Blues.musicxml  MusicXML Part        1             E     3.0   
4  1974%20Blues.musicxml  MusicXML Part        1             B     2.0   

  Duration   Type Active Chord Chord Duration  
0      0.0  Chord           F7            2.0  
1      0.5   Note           F7            2.0  
2      0.5   Note           F7            2.0  
3      0.5   Note           F7            2.0  
4      0.5   Note           F7            2.0  
Columns: Index(['File', 'Part', 'Measure', 'Note', 'Octave', 'Duration', 'Type',
       'Active Chord', 'Chord Duration'],
      dtype='object')
Note Vocabulary mapping: {'A': 1, 'A#': 2, 'A-': 3, 'B': 4, 'B#': 5, 'B-': 6, 'B--': 7, 'C': 8, 'C#': 9, 'C-': 10, 'D': 11

In [3]:
# Enharmonic equivalents dictionary: here '-' is used to denote flats.
enharmonic_equivalents = {
    'C#': 'D-', 'D-': 'C#',
    'D#': 'E-', 'E-': 'D#',
    'F#': 'G-', 'G-': 'F#',
    'G#': 'A-', 'A-': 'G#',
    'A#': 'B-', 'B-': 'A#'
}

def augment_notes(note_str, p=0.5):
    """
    With probability p, replace a note token with its enharmonic equivalent.
    """
    tokens = note_str.split()
    augmented_tokens = []
    for token in tokens:
        if token in enharmonic_equivalents and random.random() < p:
            augmented_tokens.append(enharmonic_equivalents[token])
        else:
            augmented_tokens.append(token)
    return " ".join(augmented_tokens)

def tokenize_notes(note_str, mapping):
    """
    Convert a space-separated note string into a list of integers using the note mapping.
    """
    return [mapping[note] for note in note_str.split() if note in mapping]

def process_octaves(octave_input, max_seq_length):
    """
    Process octave information.
    - If the input is a string, assume it is space-separated and split it.
    - If it's not a string (e.g., a float or int), assume it's a single value and replicate it.
    Then pad/truncate to max_seq_length.
    """
    if isinstance(octave_input, str):
        tokens = [int(o) for o in octave_input.split()]
    else:
        tokens = [int(octave_input)]

    # Pad or truncate the list to max_seq_length
    if len(tokens) < max_seq_length:
        tokens = tokens + [0] * (max_seq_length - len(tokens))
    else:
        tokens = tokens[:max_seq_length]
    return tokens

def process_note_durations(duration_str, max_seq_length):
    """
    Convert a space-separated duration string into a list of floats and pad/truncate to max_seq_length.
    This version handles fractional durations (e.g., "1/3").
    """
    tokens = []
    for d in duration_str.split():
        try:
            tokens.append(float(d))
        except ValueError:
            try:
                tokens.append(float(Fraction(d)))
            except Exception:
                tokens.append(0.0)
    if len(tokens) < max_seq_length:
        tokens = tokens + [0.0] * (max_seq_length - len(tokens))
    else:
        tokens = tokens[:max_seq_length]
    return tokens

In [4]:
from fractions import Fraction

def convert_to_float(value):
    """
    Converts a string value to float.
    If the string is a fraction (e.g., '2/3'), it converts it appropriately.
    """
    try:
        return float(value)
    except ValueError:
        return float(Fraction(value))

class MusicDataset(Dataset):
    def __init__(self, dataframe, note_mapping, chord_mapping, max_seq_length=32, augment=False):
        # Filter rows to only include those where 'Type' equals "Note" and chord label is not NaN
        self.data = dataframe[(dataframe['Type'] == "Note") & (dataframe['Active Chord'].notna())].reset_index(drop=True)
        self.note_mapping = note_mapping
        self.chord_mapping = chord_mapping
        self.max_seq_length = max_seq_length
        self.augment = augment

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Process the note string (with optional augmentation)
        note_str = row['Note']
        if self.augment:
            note_str = augment_notes(note_str)
        tokens = tokenize_notes(note_str, self.note_mapping)
        if len(tokens) < self.max_seq_length:
            tokens = tokens + [0] * (self.max_seq_length - len(tokens))
        else:
            tokens = tokens[:self.max_seq_length]

        # Process octave and note duration sequences
        octaves = process_octaves(row['Octave'], self.max_seq_length)
        note_durations = process_note_durations(row['Duration'], self.max_seq_length)

        # Process chord label and chord duration target
        chord_label = self.chord_mapping[str(row['Active Chord'])]
        chord_duration = convert_to_float(row['Chord Duration'])

        # Convert all data into tensors
        tokens_tensor = torch.tensor(tokens, dtype=torch.long)
        octaves_tensor = torch.tensor(octaves, dtype=torch.long)
        note_durations_tensor = torch.tensor(note_durations, dtype=torch.float)
        chord_label_tensor = torch.tensor(chord_label, dtype=torch.long)
        chord_duration_tensor = torch.tensor(chord_duration, dtype=torch.float)

        # Return tuple: (inputs, (classification target, regression target))
        return (tokens_tensor, octaves_tensor, note_durations_tensor), (chord_label_tensor, chord_duration_tensor)

# Create the dataset and DataLoader; set augment=True to enable augmentation.
dataset = MusicDataset(df, note2idx, chord2idx, max_seq_length=32, augment=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Test by retrieving one batch
for batch in dataloader:
    inputs, targets = batch
    tokens, octaves, note_durations = inputs
    chord_labels, chord_durations = targets
    print("Tokens shape:", tokens.shape)
    print("Octaves shape:", octaves.shape)
    print("Note durations shape:", note_durations.shape)
    print("Chord labels shape:", chord_labels.shape)
    print("Chord durations shape:", chord_durations.shape)
    break


Tokens shape: torch.Size([32, 32])
Octaves shape: torch.Size([32, 32])
Note durations shape: torch.Size([32, 32])
Chord labels shape: torch.Size([32])
Chord durations shape: torch.Size([32])


In [5]:
import torch
import torch.nn as nn

class ChordPredictor(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_layers, num_classes, max_seq_length, num_octaves=10):
        """
        num_octaves: maximum number of octave categories (adjust if your octave range is larger)
        """
        super(ChordPredictor, self).__init__()
        # Embedding for note tokens
        self.note_embed = nn.Embedding(vocab_size, embed_dim)
        # Embedding for octave values (assumes octave values are small integers)
        self.octave_embed = nn.Embedding(num_octaves, embed_dim)
        # Linear projection for note durations (continuous values)
        self.duration_linear = nn.Linear(1, embed_dim)

        # Learnable positional encoding
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_seq_length, embed_dim))

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Additional fully connected layers after transformer pooling
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        
        # Two output heads:
        # - For chord classification
        self.fc_class = nn.Linear(hidden_dim // 2, num_classes)
        # - For chord duration regression
        self.fc_duration = nn.Linear(hidden_dim // 2, 1)

    def forward(self, tokens, octaves, note_durations):
        """
        tokens: LongTensor of shape [batch_size, seq_length]
        octaves: LongTensor of shape [batch_size, seq_length]
        note_durations: FloatTensor of shape [batch_size, seq_length]
        """
        token_emb = self.note_embed(tokens)           # [B, L, embed_dim]
        octave_emb = self.octave_embed(octaves)         # [B, L, embed_dim]
        duration_emb = self.duration_linear(note_durations.unsqueeze(-1))  # [B, L, embed_dim]

        # Sum embeddings and add positional encoding
        x = token_emb + octave_emb + duration_emb + self.pos_embedding  # [B, L, embed_dim]

        # Transformer expects input of shape [L, B, embed_dim]
        x = x.permute(1, 0, 2)
        x = self.transformer_encoder(x)

        # Use the first token's output as a pooled representation
        pooled = x[0]  # [B, embed_dim]

        # Additional layers for further processing
        x = self.fc1(pooled)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.dropout(x)

        # Final output heads
        chord_logits = self.fc_class(x)              # [B, num_classes]
        chord_duration = self.fc_duration(x).squeeze(-1)  # [B]
        return chord_logits, chord_duration

# Hyperparameters (adjust these values based on your dataset)
vocab_size = len(note2idx) + 1   # +1 for padding index 0
embed_dim = 64
num_heads = 4
hidden_dim = 128
num_layers = 2
num_classes = len(chord2idx)     # Number of unique chords
max_seq_length = 32
num_octaves = 10  # Adjust if your octave range differs

# Initialize the model
model = ChordPredictor(vocab_size, embed_dim, num_heads, hidden_dim, num_layers, num_classes, max_seq_length, num_octaves)
print(model)



ChordPredictor(
  (note_embed): Embedding(26, 64)
  (octave_embed): Embedding(10, 64)
  (duration_linear): Linear(in_features=1, out_features=64, bias=True)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc1): Linear(in_features=64, out_features=128, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (fc2): Li



In [6]:
import optuna
import torch.optim as optim

# Set device for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def objective(trial):
    # Suggest embedding dimension from a fixed list
    embed_dim = trial.suggest_categorical("embed_dim", [32, 64, 128])

    # Choose valid num_heads that evenly divide embed_dim from the available list
    valid_num_heads = [h for h in [2, 4, 6, 8] if embed_dim % h == 0]
    num_heads = trial.suggest_categorical("num_heads", valid_num_heads)

    hidden_dim = trial.suggest_int("hidden_dim", 64, 256, step=64)
    num_layers = trial.suggest_int("num_layers", 1, 4)
    learning_rate = trial.suggest_loguniform("lr", 1e-4, 1e-2)

    # Create model instance with suggested hyperparameters
    model = ChordPredictor(vocab_size, embed_dim, num_heads, hidden_dim, num_layers,
                           num_classes, max_seq_length, num_octaves)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion_class = nn.CrossEntropyLoss()
    criterion_reg = nn.MSELoss()

    # Train for a fixed number of epochs (adjust as needed)
    num_epochs = 150
    model.train()
    for epoch in range(num_epochs):
        print("epoch no:", epoch)
        total_loss = 0.0
        for batch in dataloader:
            (tokens, octaves, note_durations), (chord_labels, chord_durations) = batch
            tokens = tokens.to(device)
            octaves = octaves.to(device)
            note_durations = note_durations.to(device)
            chord_labels = chord_labels.to(device)
            chord_durations = chord_durations.to(device)

            optimizer.zero_grad()
            chord_logits, chord_duration_pred = model(tokens, octaves, note_durations)
            loss_class = criterion_class(chord_logits, chord_labels)
            loss_reg = criterion_reg(chord_duration_pred, chord_durations)
            loss = loss_class + loss_reg
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print("epoch no:", epoch, "avg loss", avg_loss)
        trial.report(avg_loss, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return avg_loss


In [7]:
'''study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)  # Increase n_trials for a more thorough search if desired

print("Best trial:")
trial = study.best_trial
print("  Loss: {:.4f}".format(trial.value))
print("  Params:")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))
'''

'study = optuna.create_study(direction="minimize")\nstudy.optimize(objective, n_trials=20)  # Increase n_trials for a more thorough search if desired\n\nprint("Best trial:")\ntrial = study.best_trial\nprint("  Loss: {:.4f}".format(trial.value))\nprint("  Params:")\nfor key, value in trial.params.items():\n    print("    {}: {}".format(key, value))\n'

In [8]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# Split the dataset into training and validation sets (80% training, 20% validation)
train_dataset, val_dataset = train_test_split(dataset, test_size=0.2, random_state=42)

# Create DataLoaders for training and validation
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print("Number of training samples:", len(train_dataset))
print("Number of validation samples:", len(val_dataset))

Number of training samples: 22019
Number of validation samples: 5505


In [40]:
import torch
import torch.nn as nn
import torch.optim as optim
from fractions import Fraction

# Define a helper function to convert fractional strings to float
def convert_to_float(value):
    try:
        return float(value)
    except ValueError:
        return float(Fraction(value))

# --- Best Trial Hyperparameters ---
# These parameters are from your best trial:
embed_dim = 32
num_heads = 4
hidden_dim = 256 # was 192 before
num_layers = 4
learning_rate = 0.0007428326378443678
num_epochs = 325

# Ensure device is set correctly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model using the best trial hyperparameters.
# Ensure that vocab_size, num_classes, max_seq_length, and num_octaves are defined.
model = ChordPredictor(vocab_size, embed_dim, num_heads, hidden_dim, num_layers, 
                       num_classes, max_seq_length, num_octaves)
model.to(device)

# Define the optimizer and loss functions
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion_class = nn.CrossEntropyLoss()
criterion_reg = nn.MSELoss()

# --- Training Loop ---
print("Starting training...\n")
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for batch in train_dataloader:
        # Unpack batch data
        (tokens, octaves, note_durations), (chord_labels, chord_durations) = batch
        tokens = tokens.to(device)
        octaves = octaves.to(device)
        note_durations = note_durations.to(device)
        chord_labels = chord_labels.to(device)
        chord_durations = chord_durations.to(device)
        
        optimizer.zero_grad()
        chord_logits, chord_duration_pred = model(tokens, octaves, note_durations)
        loss_class = criterion_class(chord_logits, chord_labels)
        loss_reg = criterion_reg(chord_duration_pred, chord_durations)
        loss = loss_class + loss_reg
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")




Starting training...

Epoch 1/325, Loss: 4.9663
Epoch 2/325, Loss: 4.7397
Epoch 3/325, Loss: 4.7127
Epoch 4/325, Loss: 4.6615
Epoch 5/325, Loss: 4.6304
Epoch 6/325, Loss: 4.6156
Epoch 7/325, Loss: 4.5811
Epoch 8/325, Loss: 4.5499
Epoch 9/325, Loss: 4.5352
Epoch 10/325, Loss: 4.4989
Epoch 11/325, Loss: 4.4652
Epoch 12/325, Loss: 4.4396
Epoch 13/325, Loss: 4.4276
Epoch 14/325, Loss: 4.4103
Epoch 15/325, Loss: 4.3961
Epoch 16/325, Loss: 4.3838
Epoch 17/325, Loss: 4.3740
Epoch 18/325, Loss: 4.3598
Epoch 19/325, Loss: 4.3545
Epoch 20/325, Loss: 4.3488
Epoch 21/325, Loss: 4.3406
Epoch 22/325, Loss: 4.3324
Epoch 23/325, Loss: 4.3275
Epoch 24/325, Loss: 4.3220
Epoch 25/325, Loss: 4.3079
Epoch 26/325, Loss: 4.2999
Epoch 27/325, Loss: 4.2978
Epoch 28/325, Loss: 4.2851
Epoch 29/325, Loss: 4.2862
Epoch 30/325, Loss: 4.2885
Epoch 31/325, Loss: 4.2671
Epoch 32/325, Loss: 4.2645
Epoch 33/325, Loss: 4.2692
Epoch 34/325, Loss: 4.2575
Epoch 35/325, Loss: 4.2445
Epoch 36/325, Loss: 4.2390
Epoch 37/325, L

In [41]:
# --- Validation Function ---
def validate_model(model, dataloader, device):
    model.eval()
    total_class_loss = 0.0
    total_reg_loss = 0.0
    total_correct = 0
    total_samples = 0
    
    # Use sum reduction to later average the loss per sample
    criterion_class_val = nn.CrossEntropyLoss(reduction='sum')
    criterion_reg_val = nn.MSELoss(reduction='sum')
    
    with torch.no_grad():
        for batch in dataloader:
            (tokens, octaves, note_durations), (chord_labels, chord_durations) = batch
            tokens = tokens.to(device)
            octaves = octaves.to(device)
            note_durations = note_durations.to(device)
            chord_labels = chord_labels.to(device)
            chord_durations = chord_durations.to(device)
            
            chord_logits, chord_duration_pred = model(tokens, octaves, note_durations)
            
            class_loss = criterion_class_val(chord_logits, chord_labels)
            reg_loss = criterion_reg_val(chord_duration_pred, chord_durations)
            total_class_loss += class_loss.item()
            total_reg_loss += reg_loss.item()
            
            # Calculate accuracy for chord classification
            _, predicted = torch.max(chord_logits, 1)
            total_correct += (predicted == chord_labels).sum().item()
            total_samples += chord_labels.size(0)
    
    avg_class_loss = total_class_loss / total_samples
    avg_reg_loss = total_reg_loss / total_samples
    accuracy = total_correct / total_samples
    
    return avg_class_loss, avg_reg_loss, accuracy

# --- Run Validation ---
avg_class_loss, avg_reg_loss, accuracy = validate_model(model, val_dataloader, device)
print("\nValidation Results:")
print("Validation Classification Loss: {:.4f}".format(avg_class_loss))
print("Validation Regression Loss: {:.4f}".format(avg_reg_loss))
print("Validation Accuracy: {:.4f}".format(accuracy))


Validation Results:
Validation Classification Loss: 4.2371
Validation Regression Loss: 1.1956
Validation Accuracy: 0.1074


In [11]:
!pip install safetensors



In [44]:
from safetensors.torch import save_file, load_file

# ----- Save the model as a safetensors file -----
safetensors_path = 'transformersv2/chord_predictor.safetensors'
save_file(model.state_dict(), safetensors_path)
print(f"Model saved as safetensors in {safetensors_path}")


Model saved as safetensors in transformersv2/chord_predictor.safetensors


In [48]:
# Example encoding dictionaries (replace these with your actual mappings)
note2idx = {'A': 1, 'A#': 2, 'A-': 3, 'B': 4, 'B#': 5, 'B-': 6, 'B--': 7, 'C': 8, 'C#': 9, 'C-': 10, 'D': 11, 'D#': 12, 'D-': 13, 'E': 14, 'E#': 15, 'E-': 16, 'E--': 17, 'F': 18, 'F#': 19, 'F-': 20, 'G': 21, 'G#': 22, 'G-': 23, 'G--': 24, 'G---': 25}  # 0 reserved for padding
idx2note = {idx: note for note, idx in note2idx.items()}

chord2idx = {'A': 0, 'A#m7': 1, 'A-': 2, 'A-+': 3, 'A-/B': 4, 'A-/B-': 5, 'A-/E-': 6, 'A-/G-': 7, 'A-13': 8, 'A-6': 9, 'A-7': 10, 'A-7 add #11': 11, 'A-7 add #9': 12, 'A-9': 13, 'A-M13': 14, 'A-M13 alter #11': 15, 'A-M9': 16, 'A-dim': 17, 'A-m': 18, 'A-m/B-': 19, 'A-m/E-': 20, 'A-m11': 21, 'A-m7': 22, 'A-m7/B-': 23, 'A-m7/D-': 24, 'A-m9': 25, 'A-maj7': 26, 'A-maj7/B-': 27, 'A-sus': 28, 'A-sus add 7': 29, 'A-sus/B- add 7': 30, 'A/E': 31, 'A13': 32, 'A6': 33, 'A7': 34, 'A7 add #11': 35, 'A7 add #9': 36, 'A7 add b9': 37, 'A9': 38, 'A9 add #11': 39, 'AM13': 40, 'AM9': 41, 'Adim': 42, 'Am': 43, 'Am11': 44, 'Am7': 45, 'Am7 alter b5': 46, 'Am7/D': 47, 'Am7/G': 48, 'Am9': 49, 'Amaj7': 50, 'Ao7': 51, 'Asus': 52, 'B': 53, 'B add 9': 54, 'B-': 55, 'B-/A': 56, 'B-/A-': 57, 'B-/B': 58, 'B-/C': 59, 'B-/D': 60, 'B-/F': 61, 'B-13': 62, 'B-13 alter #9': 63, 'B-6': 64, 'B-7': 65, 'B-7 add #11': 66, 'B-7 add #9': 67, 'B-7/C': 68, 'B-7/D#': 69, 'B-9': 70, 'B-M9': 71, 'B-m': 72, 'B-m11': 73, 'B-m6': 74, 'B-m7': 75, 'B-m7/E-': 76, 'B-m7/F': 77, 'B-m7/G': 78, 'B-m9': 79, 'B-mM7': 80, 'B-maj7': 81, 'B-maj7/C': 82, 'B-o7': 83, 'B-sus add 7': 84, 'B-sus/C add 7': 85, 'B/D': 86, 'B/D-': 87, 'B/E': 88, 'B/E-': 89, 'B13': 90, 'B7': 91, 'B7 add #9': 92, 'B7 add b9': 93, 'B7/A': 94, 'B9': 95, 'B9 add #11': 96, 'B9/F#': 97, 'BM9': 98, 'Bdim': 99, 'Bm': 100, 'Bm/F#': 101, 'Bm11': 102, 'Bm13': 103, 'Bm6': 104, 'Bm7': 105, 'Bm7 alter b5': 106, 'Bm7/A': 107, 'Bm9': 108, 'BmM7': 109, 'Bmaj7': 110, 'Bmaj7 add #11': 111, 'Bmaj7/D#': 112, 'Bmaj7/F': 113, 'Bo7': 114, 'Bsus': 115, 'C': 116, 'C#': 117, 'C#7': 118, 'C#7 add #9': 119, 'C#dim': 120, 'C#dim/B-': 121, 'C#m': 122, 'C#m11': 123, 'C#m13': 124, 'C#m7': 125, 'C#maj7': 126, 'C#o7': 127, 'C#sus add 7': 128, 'C/A': 129, 'C/A-': 130, 'C/B-': 131, 'C/D': 132, 'C/E': 133, 'C/F#': 134, 'C/G': 135, 'C13': 136, 'C6': 137, 'C7': 138, 'C7 add #11': 139, 'C7 add #9': 140, 'C7 add b9': 141, 'C7+': 142, 'C7/E': 143, 'C7/G': 144, 'C9': 145, 'CM9': 146, 'Cdim': 147, 'Cm': 148, 'Cm/B-': 149, 'Cm11': 150, 'Cm13': 151, 'Cm6': 152, 'Cm7': 153, 'Cm7 alter b5': 154, 'Cm7/B-': 155, 'Cm9': 156, 'CmM7': 157, 'Cmaj7': 158, 'Cmaj7 add #11': 159, 'Co7': 160, 'Co7/B-': 161, 'Csus': 162, 'Csus add 7': 163, 'D': 164, 'D#': 165, 'D#7': 166, 'D#9': 167, 'D#dim': 168, 'D#m': 169, 'D#m7': 170, 'D#m7 alter b5': 171, 'D#o7': 172, 'D-': 173, 'D-/A-': 174, 'D-/E': 175, 'D-/E-': 176, 'D-/G-': 177, 'D-13': 178, 'D-7': 179, 'D-7 add #11': 180, 'D-7 add #9': 181, 'D-9': 182, 'D-M13': 183, 'D-M13 alter #11': 184, 'D-M9': 185, 'D-dim': 186, 'D-m': 187, 'D-m11': 188, 'D-m7': 189, 'D-maj7': 190, 'D-maj7/C': 191, 'D-maj7/E-': 192, 'D-maj7/F': 193, 'D-sus add 7': 194, 'D/E': 195, 'D13': 196, 'D6': 197, 'D7': 198, 'D7 add #11': 199, 'D7 add #9': 200, 'D7 add b9': 201, 'D7+': 202, 'D9': 203, 'D9 add #11': 204, 'DM13': 205, 'Ddim': 206, 'Dm': 207, 'Dm/A': 208, 'Dm/E': 209, 'Dm11': 210, 'Dm13': 211, 'Dm7': 212, 'Dm7 alter b5': 213, 'Dm9': 214, 'DmM7 add 9': 215, 'Dmaj7': 216, 'Dmaj7/E': 217, 'Do7': 218, 'Dpower': 219, 'Dsus add 7': 220, 'E': 221, 'E-': 222, 'E-/B-': 223, 'E-/D-': 224, 'E-/E': 225, 'E-13': 226, 'E-6': 227, 'E-7': 228, 'E-7 add #11': 229, 'E-7 add #9': 230, 'E-7 add #9 add #11': 231, 'E-7 alter b5': 232, 'E-9': 233, 'E-9 add #11': 234, 'E-9 add 13': 235, 'E-M13': 236, 'E-M13 alter #11': 237, 'E-M9': 238, 'E-dim': 239, 'E-m': 240, 'E-m11': 241, 'E-m6': 242, 'E-m7': 243, 'E-m7/D-': 244, 'E-m9': 245, 'E-maj7': 246, 'E-maj7/F': 247, 'E-o7': 248, 'E/A': 249, 'E/B': 250, 'E/D': 251, 'E/F': 252, 'E/F#': 253, 'E/G#': 254, 'E13': 255, 'E7': 256, 'E7 add #9': 257, 'E7 alter #5': 258, 'E7 alter b5': 259, 'E7+': 260, 'E7/F': 261, 'E9': 262, 'EM13': 263, 'EM9': 264, 'Edim': 265, 'Em': 266, 'Em11': 267, 'Em13': 268, 'Em7': 269, 'Em7 add 11': 270, 'Em7 alter b5': 271, 'Em7/D': 272, 'Em9': 273, 'Emaj7': 274, 'Emaj7/F#': 275, 'Eo7': 276, 'Esus add 7': 277, 'F': 278, 'F#': 279, 'F#/E': 280, 'F#13': 281, 'F#7': 282, 'F#7 add #9': 283, 'F#7 alter b5': 284, 'F#7/C#': 285, 'F#9': 286, 'F#dim': 287, 'F#m': 288, 'F#m/A': 289, 'F#m11': 290, 'F#m7': 291, 'F#m7 add 11': 292, 'F#m7 alter b5': 293, 'F#m9': 294, 'F#maj7': 295, 'F#o7': 296, 'F#sus': 297, 'F#sus add 7': 298, 'F/B-': 299, 'F/C': 300, 'F/E': 301, 'F/E-': 302, 'F/G': 303, 'F13': 304, 'F6': 305, 'F7': 306, 'F7 add #11': 307, 'F7 add #9': 308, 'F7/A': 309, 'F9': 310, 'F9 add #11': 311, 'Fdim': 312, 'Fm': 313, 'Fm/C': 314, 'Fm11': 315, 'Fm6': 316, 'Fm7': 317, 'Fm7 alter b5': 318, 'Fm7/E-': 319, 'Fm9': 320, 'FmM7': 321, 'FmM7 add 9': 322, 'Fmaj7': 323, 'Fmaj7/G': 324, 'Fo7': 325, 'Fpower': 326, 'Fsus add 7': 327, 'G': 328, 'G#': 329, 'G#7': 330, 'G#m7': 331, 'G#maj7': 332, 'G#o7': 333, 'G#sus': 334, 'G-': 335, 'G-13': 336, 'G-7': 337, 'G-7 add #11': 338, 'G-9': 339, 'G-M13 alter #11': 340, 'G-dim': 341, 'G-m': 342, 'G-m11': 343, 'G-m7': 344, 'G-m9': 345, 'G-maj7': 346, 'G-maj7/A-': 347, 'G-o7': 348, 'G-sus': 349, 'G/A': 350, 'G/B-': 351, 'G13': 352, 'G6': 353, 'G7': 354, 'G7 add #11': 355, 'G7 add #9': 356, 'G7 add b9': 357, 'G7 alter #5': 358, 'G7 alter b5': 359, 'G7/F': 360, 'G9': 361, 'GM13': 362, 'Gdim': 363, 'Gm': 364, 'Gm/B-': 365, 'Gm/E': 366, 'Gm/F': 367, 'Gm/G-': 368, 'Gm11': 369, 'Gm13': 370, 'Gm6': 371, 'Gm7': 372, 'Gm7 alter b5': 373, 'Gm7/B-': 374, 'Gm7/C': 375, 'Gm7/F': 376, 'Gm7/G-': 377, 'Gm9': 378, 'GmM7': 379, 'GmM7 add 9 add 11': 380, 'Gmaj7': 381, 'Gmaj7/A': 382, 'Gsus': 383}
idx2chord = {idx: chord for chord, idx in chord2idx.items()}

# Hyperparameters (adjust as needed)
vocab_size = len(note2idx) + 1  # +1 for padding
embed_dim = 32
num_heads = 4
hidden_dim = 256
num_layers = 4
num_classes = len(chord2idx)
max_seq_length = 32
num_octaves = 10

# ----- Load the model for prediction -----
model_pred = ChordPredictor(vocab_size, embed_dim, num_heads, hidden_dim, num_layers, num_classes, max_seq_length, num_octaves)
state_dict = load_file(safetensors_path)
model_pred.load_state_dict(state_dict)
model_pred.eval()

# ----- Prepare a sample input -----
# Define a sample sequence of notes, octaves, and durations (adjust as needed)
sample_notes = ['C', 'D', 'E', 'F', 'G', 'A', 'B', 'C']
sample_octaves = [3, 3, 3, 4, 4, 4, 4, 3]
sample_durations = [0.5, 0.5, 0.75, 0.5, 1.0, 0.5, 0.25, 0.75]

# Encode the sample notes using note2idx
sample_tokens = [note2idx[note] for note in sample_notes]

# Pad the sequences to max_seq_length
pad_length = max_seq_length - len(sample_tokens)
sample_tokens_padded = sample_tokens + [0] * pad_length
sample_octaves_padded = sample_octaves + [0] * pad_length
sample_durations_padded = sample_durations + [0.0] * pad_length

# Convert to tensors (batch size = 1)
tokens_tensor = torch.tensor([sample_tokens_padded], dtype=torch.long)
octaves_tensor = torch.tensor([sample_octaves_padded], dtype=torch.long)
durations_tensor = torch.tensor([sample_durations_padded], dtype=torch.float)

# ----- Decode and display the input sequence -----
decoded_input = [idx2note.get(tok, 'PAD') if tok != 0 else 'PAD' for tok in sample_tokens_padded]
print("Decoded Input Notes:")
print(decoded_input)

# ----- Make a prediction -----
with torch.no_grad():
    chord_logits, chord_duration = model_pred(tokens_tensor, octaves_tensor, durations_tensor)
    # Convert logits to predicted chord index
    predicted_chord_idx = torch.argmax(chord_logits, dim=-1).item()
    predicted_chord = idx2chord[predicted_chord_idx]
    predicted_duration = chord_duration.item()

# ----- Display the prediction -----
print("\nPredicted Chord:")
print(predicted_chord)
print("\nPredicted Chord Duration:")
print(predicted_duration)


Decoded Input Notes:
['C', 'D', 'E', 'F', 'G', 'A', 'B', 'C', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD']

Predicted Chord:
F7

Predicted Chord Duration:
3.54272723197937
