In [81]:
import pandas as pd
import numpy as np
import re

file_path = './data/results/tokenized/bart_reg_1.0/RootPCTokenizer.csv'

df = pd.read_csv(file_path)


# Convert the 'real' and 'generated' columns to lists
real_list = df["real"].tolist()
generated_list = df["generated"].tolist()


In [82]:
def preprocess_sequence(seq):
    tokens = seq.split()
    
    # 1) Remove the leading <h> token if present
    if tokens and tokens[0] == "<h>":
        tokens.pop(0)  # remove <h>
    
    # 2) If the *first* token is <bar> and the *second* token is also <bar>, remove the second one
    if len(tokens) > 1 and tokens[0] == "<bar>" and tokens[1] == "<bar>":
        tokens.pop(1)  # remove the duplicate <bar>
    
    # Return the cleaned-up string
    return " ".join(tokens)

# Preprocess both lists
real_list = [preprocess_sequence(seq) for seq in real_list]
generated_list = [preprocess_sequence(seq) for seq in generated_list]

In [83]:
gt_seq = real_list[5]
print(gt_seq)
pred_seq = generated_list[5]
print(pred_seq)

<bar> position_0x00 chord_root_5 chord_pc_9 chord_pc_0 position_1x50 chord_root_0 chord_pc_4 chord_pc_7 <bar> position_0x00 chord_root_7 chord_pc_11 chord_pc_2 position_1x50 chord_root_9 chord_pc_0 chord_pc_4 <bar> position_0x00 chord_root_5 chord_pc_9 chord_pc_0 position_1x50 chord_root_0 chord_pc_4 chord_pc_7 <bar> position_0x00 chord_root_7 chord_pc_11 chord_pc_2 position_1x50 chord_root_9 chord_pc_0 chord_pc_4 <bar> position_0x00 chord_root_5 chord_pc_9 chord_pc_0 position_1x50 chord_root_0 chord_pc_4 chord_pc_7 <bar> position_0x00 chord_root_7 chord_pc_11 chord_pc_2 position_1x50 chord_root_9 chord_pc_0 chord_pc_4 <bar> position_0x00 chord_root_5 chord_pc_9 chord_pc_0 position_1x50 chord_root_0 chord_pc_4 chord_pc_7 <bar> position_0x00 chord_root_9 chord_pc_0 chord_pc_4 position_1x50 chord_root_7 chord_pc_11 chord_pc_2 </s>
<bar> position_0x00 chord_root_0 chord_pc_4 chord_pc_7 <bar> position_0x00 chord_root_7 chord_pc_11 chord_pc_2 <bar> position_0x00 chord_root_9 chord_pc_0 chor

In [84]:

# RootPCTokenizer
def check_token_consistency(seq):
    """
    Checks the consistency of token transitions based on predefined rules, 
    including cases where chords are represented with separate root and pitch class tokens.

    Args:
        seq (list): A list of tokens in the sequence.

    Returns:
        float: Consistency ratio (percentage of correct transitions).
    """

    # Regular expression patterns for token types
    position_pattern = re.compile(r'position_\d+x\d+')  # Matches position_AxBB
    chord_root_pattern = re.compile(r'chord_root_\d{1,2}')  # Matches chord root tokens like chord_root_3
    pitch_class_pattern = re.compile(r'chord_pc_\d{1,2}')  # Matches pitch class tokens like chord_pc_10

    total_transitions = 0
    valid_transitions = 0
    i = 0

    while i < len(seq) - 1:
        current_token = seq[i]
        next_token = seq[i + 1]
        total_transitions += 1

        # Rule 1: <bar> must be followed by position_AxBB
        if current_token == "<bar>":
            if position_pattern.match(next_token):
                valid_transitions += 1
        # Rule 2: position_AxBB must be followed by chord_root_X
        elif position_pattern.match(current_token):
            if chord_root_pattern.match(next_token):
                valid_transitions += 1
        # Rule 3: chord_root_X must be followed by 3, 4, or 5 chord_pc_X tokens
        elif chord_root_pattern.match(current_token):
            chord_count = 0
            while i + 1 < len(seq) and pitch_class_pattern.match(seq[i + 1]):
                chord_count += 1
                i += 1  # Move to the next chord_pc_X token
            
            if chord_count in {2, 3, 4}:
                valid_transitions += 1  # Chord group size is valid
        # Rule 4: chord_pc_X tokens must be followed by position_AxBB, <bar>, or </s>
        elif pitch_class_pattern.match(current_token):
            if position_pattern.match(next_token) or next_token in ["<bar>", "</s>"]:
                valid_transitions += 1

        i += 1  # Move to the next token

    # Calculate consistency ratio as a percentage
    consistency_ratio = (valid_transitions / total_transitions) * 100 if total_transitions > 0 else 100

    return consistency_ratio



def calculate_duplicate_error_ratio(seq):
    """
    Calculate the total duplicate error ratio for the sequence.

    Args:
        seq (list): List of tokens in the sequence.

    Returns:
        float: The overall duplicate error ratio as a percentage.
    """
    duplicate_count = 0
    total_tokens = len(seq)

    # Count consecutive duplicate tokens
    for i in range(len(seq) - 1):
        if seq[i] == seq[i + 1]:  # Consecutive duplicate found
            duplicate_count += 1

    # Calculate the total duplicate ratio as a percentage
    duplicate_ratio = (duplicate_count / total_tokens) * 100 if total_tokens > 0 else 0

    return duplicate_ratio


def count_bar_tokens(gt_seq, pred_seq):
    # Count occurrences of '<bar>' in both sequences
    gt_bar_count = gt_seq.count('<bar>')
    pred_bar_count = pred_seq.count('<bar>')
    
    # Calculate the difference only if counts are not identical
    bar_diff = abs(gt_bar_count - pred_bar_count) if gt_bar_count != pred_bar_count else 0

    # Count exact matches 
    is_bar_count_correct = gt_bar_count == pred_bar_count

    return bar_diff, is_bar_count_correct


def low_level_metrics(gt_seqs, pred_seqs):
    token_diff = []
    bar_diff_list = []
    num_no_eos = 0
    num_bar_identical = 0
    duplicate_error_ratios = []
    consistency_token_ratios = []

    prob_seqs = []

    for i in range(0, len(gt_seqs)):

        gt_seq = gt_seqs[i]
        gt_seq = gt_seq.split()
        pred_seq = pred_seqs[i] 
        pred_seq = pred_seq.split()

    
        # Get predicted sequence
        try:
            eos_idx_pred = pred_seq.index('</s>')  # Find the first occurrence of </s> in predictions
            token_diff.append(len(gt_seq) - len(pred_seq))
        except ValueError:
            num_no_eos += 1  # Count sequences without </s>
    

        # Calculate bar token differences and matches
        bar_diff, bar_match = count_bar_tokens(gt_seq, pred_seq)

        if bar_match:
            num_bar_identical += 1
        else:
            bar_diff_list.append(bar_diff)
            prob_seqs.append(pred_seq)


        # Calculate duplicate token error ratio
        duplicate_ratio = calculate_duplicate_error_ratio(pred_seq)
        duplicate_error_ratios.append(duplicate_ratio)

        # Calculate token consistency ratio
        consistency_ratio = check_token_consistency(pred_seq)
        consistency_token_ratios.append(consistency_ratio)


    return token_diff, num_no_eos, bar_diff_list, num_bar_identical, duplicate_error_ratios, consistency_token_ratios, prob_seqs 

In [85]:
token_diff, num_no_eos, bar_diff, num_bar_identical, duplicate_error_ratios, consistency_token_ratios, prob_seqs  = low_level_metrics(real_list, generated_list)

print(f"Sequences without </s>: {num_no_eos}")
print(f"Number of sequences with identical bar counts: {num_bar_identical}")
token_diff_mean = np.mean(token_diff)
token_diff_std = np.std(token_diff)
print(f"Total Token Length Differences: Mean = {token_diff_mean:.2f}, Std = {token_diff_std:.2f}")
bar_diff_mean = np.mean(bar_diff)
bar_diff_std = np.std(bar_diff)
print(f"Bar Token Differences: Mean = {bar_diff_mean:.2f}, Std = {bar_diff_std:.2f}")
duplicate_error_mean = np.mean(duplicate_error_ratios)
duplicate_error_std = np.std(duplicate_error_ratios)
print(f"Duplicate Token Error Ratio: Mean = {duplicate_error_mean:.2f}, Std = {duplicate_error_std:.2f}")
consistency_ratio_mean = np.mean(consistency_token_ratios)
consistency_ratio_std = np.std(consistency_token_ratios)
print(f"Token Consistency Ratio: Mean = {consistency_ratio_mean:.2f}, Std = {consistency_ratio_std:.2f}")



Sequences without </s>: 0
Number of sequences with identical bar counts: 679
Total Token Length Differences: Mean = -14.66, Std = 72.07
Bar Token Differences: Mean = 2.57, Std = 4.29
Duplicate Token Error Ratio: Mean = 0.38, Std = 3.76
Token Consistency Ratio: Mean = 98.91, Std = 5.71


In [95]:
aseq = prob_seqs[5]