In [None]:
# Work from the parent directory so the project-local `data` package (and the
# relative 'data/...' paths in the config cell) resolve.
# The directory change must happen BEFORE any `data.*` import. Otherwise that
# import fails on a fresh kernel.
# It is guarded so that re-running this cell does not keep climbing the
# directory tree.
# NOTE(review): assumes the notebook's own directory has no `data/` subfolder.
import os

if not os.path.isdir('data'):
    os.chdir('..')
print(os.getcwd())

# `numpy.core` is a private, deprecated namespace in NumPy 2.x. The public
# `numpy.argwhere` is the same function.
from numpy import argwhere

from data.src.feature_extractors import rhythm_density_sync_score

In [None]:
# Analysis dependencies. The first cell moves up one directory, presumably so
# the project-local `data` package imported here can be found.
import yaml  # only referenced by the commented-out YAML config loader in the next cell
import numpy as np
import os
from data import get_flexcontrol_triplestream_dataset
import matplotlib.pyplot as plt
# Quiet matplotlib's font manager. The level set below drops its DEBUG/INFO
# log records; WARNING and above are still emitted.
import logging
import matplotlib
logging.getLogger('matplotlib.font_manager').setLevel(logging.WARNING)

In [None]:
# Equivalent settings can alternatively be loaded from YAML:
#   config = yaml.safe_load(open('helpers/configs/FlexControlTripleStreams_0.5.yaml', 'r'))

# Dataset shard files, relative to `dataset_root_path`.
DATASET_FILES = [
    '01_candombe_four_voices.pkl.bz2',
    '02_elbg_both_flattened_left_right.pkl.bz2',
    '03_groove_midi_crash_hhclosed_hhopen_ride.pkl.bz2',
    '04_groove_midi_hh_kick_snare_toms.pkl.bz2',
    '05_groove_midi_hi_lo_mid_ride.pkl.bz2',
    '06_lmd_bass_brass_drum_percussion.pkl.bz2',
    '07_lmd_bass_brass_drum_percussive.pkl.bz2',
    '08_lmd_bass_brass_guitar_percussion.pkl.bz2',
    '09_lmd_bass_brass_guitar_percussive.pkl.bz2',
    '10_lmd_bass_brass_guitar_piano.pkl.bz2',
    '11_lmd_bass_brass_percussion_percussive.pkl.bz2',
    '12_lmd_bass_brass_percussion_piano.pkl.bz2',
    '13_lmd_bass_brass_percussive_piano.pkl.bz2',
    '14_lmd_bass_drum_guitar_percussion.pkl.bz2',
    '15_lmd_bass_drum_guitar_percussive.pkl.bz2',
    '16_lmd_bass_drum_percussion_percussive.pkl.bz2',
    '17_lmd_bass_drum_percussion_piano.pkl.bz2',
    '18_lmd_bass_drum_percussive_piano.pkl.bz2',
    '19_lmd_bass_guitar_percussion_percussive.pkl.bz2',
    '20_lmd_bass_guitar_percussion_piano.pkl.bz2',
    '21_lmd_bass_guitar_percussive_piano.pkl.bz2',
    '22_lmd_bass_percussion_percussive_piano.pkl.bz2',
    '23_lmd_brass_drum_guitar_percussion.pkl.bz2',
    '24_lmd_brass_drum_guitar_percussive.pkl.bz2',
    '25_lmd_brass_drum_guitar_piano.pkl.bz2',
    '26_lmd_brass_drum_percussion_percussive.pkl.bz2',
    '27_lmd_brass_drum_percussion_piano.pkl.bz2',
    '28_lmd_brass_drum_percussive_piano.pkl.bz2',
    '29_lmd_brass_guitar_percussion_percussive.pkl.bz2',
    '30_lmd_brass_guitar_percussion_piano.pkl.bz2',
    '31_lmd_brass_guitar_percussive_piano.pkl.bz2',
    '32_lmd_brass_percussion_percussive_piano.pkl.bz2',
    '33_lmd_drum_guitar_percussion_percussive.pkl.bz2',
    '34_lmd_drum_guitar_percussion_piano.pkl.bz2',
    '35_lmd_drum_guitar_percussive_piano.pkl.bz2',
    '36_lmd_drum_percussion_percussive_piano.pkl.bz2',
    '37_lmd_guitar_percussion_percussive_piano.pkl.bz2',
    '38_ttd_both-is-and_both_flattened_left_right.pkl.bz2',
    '39_ttd_both-is-or_both_flattened_left_right.pkl.bz2',
]

# Decoding control features. Each key gets a matching `None` entry in
# 'n_decoding_control_tokens'.
DECODING_CONTROL_KEYS = [
    'Total Out Hits',
    'Output Step Density',
    'Stream 1 Relative Density',
    'Stream 2 Relative Density',
    'Stream 3 Relative Density',
]

config = dict(
    dataset_root_path='data/triple_streams/model_ready/AccentAt0.75/',
    dataset_files=DATASET_FILES,
    max_len=32,
    n_encoding_control_tokens=[],
    encoding_control_keys=[],
    n_decoding_control_tokens=[None] * len(DECODING_CONTROL_KEYS),
    decoding_control_keys=DECODING_CONTROL_KEYS,
)

# When True, the loader call below receives downsampled_size=2000.
is_testing = False

# Load the "train" split described by `config`.
# `downsampled_size` is 2000 only while `is_testing` is set; otherwise None is passed.
_downsampled_size = 2000 if is_testing else None
dataset = get_flexcontrol_triplestream_dataset(
    config=config,
    subset_tag="train",
    use_cached=True,
    downsampled_size=_downsampled_size,
    print_logs=False,  # switch to True to print the loader's logs
)

In [None]:
# Indices of samples with no microtiming, i.e. every offset value is exactly 0.
# Absolute values are summed: if offsets can be negative (TODO confirm), a plain
# sum could cancel to 0 even though microtiming is present. For non-negative
# offsets the result is unchanged.
# Channel slices below are the same ones used before: input offsets at
# channel 2, output offsets at channels 6:9.
in_offsets_sum = np.abs(dataset.input_grooves.cpu().numpy()[:, :, 2:3]).sum(axis=(1, 2))
in_offsets_sum_zero_indices = np.argwhere(in_offsets_sum == 0)

out_offsets_sum = np.abs(dataset.output_streams.cpu().numpy()[:, :, 6:9]).sum(axis=(1, 2))
out_offsets_sum_zero_indices = np.argwhere(out_offsets_sum == 0)

# (#inputs without microtiming, #outputs without microtiming, #samples where both hold)
(
    len(in_offsets_sum_zero_indices),
    len(out_offsets_sum_zero_indices),
    len(np.intersect1d(in_offsets_sum_zero_indices, out_offsets_sum_zero_indices)),
)



In [None]:
out_offsets_sum_zero_indices  # np.argwhere(out_offsets_sum == 0), already computed in the previous cell

In [None]:
# Find samples whose hit channel (index 0) exactly matches a perfectly regular
# pattern. Each template spans 32 steps; the key is its number of evenly spaced hits.
straight_hit_patterns = {
    "4": np.array([1, 0, 0, 0, 0, 0, 0, 0] * 4),
    "8": np.array([1, 0, 0, 0] * 8),
    "16": np.array([1, 0] * 16),
    "32": np.array([1] * 32),
}
# Backward-compatible alias for the original (misspelled) name.
straight_hit_pattenrs = straight_hit_patterns


def _rows_matching(hits, pattern):
    """Boolean mask over samples: True where a row of `hits` equals `pattern` element-wise."""
    return np.all(hits == pattern, axis=1)


inputs = dataset.input_grooves.cpu().numpy()[:, :, 0]
outputs = dataset.flat_output_streams.cpu().numpy()[:, :, 0]

# Fail loudly rather than broadcasting against mismatched lengths.
for _name, _hits in (("input", inputs), ("output", outputs)):
    assert all(len(p) == _hits.shape[1] for p in straight_hit_patterns.values()), (
        f"{_name} has {_hits.shape[1]} steps; straight-pattern templates assume 32"
    )

input_straight_pattern_indices = {}
output_straight_pattern_indices = {}
both_input_output_straight_pattern_indices = {}
for k, v in straight_hit_patterns.items():
    in_mask = _rows_matching(inputs, v)
    out_mask = _rows_matching(outputs, v)
    input_straight_pattern_indices[k] = np.flatnonzero(in_mask)
    output_straight_pattern_indices[k] = np.flatnonzero(out_mask)
    both_input_output_straight_pattern_indices[k] = np.flatnonzero(in_mask & out_mask)

# Report counts, labelled so input/output/both groups are distinguishable.
for label, index_dict in (
    ("input", input_straight_pattern_indices),
    ("output", output_straight_pattern_indices),
    ("both input and output", both_input_output_straight_pattern_indices),
):
    for k, idx in index_dict.items():
        print(f"{k} {label} straight pattern indices: {idx.shape[0]}")
    print()




In [None]:
from data import FlexControlGroove2TripleStream2BarDataset

# Feature dictionary computed by the project's dataset class from the raw HVO
# arrays. Later cells read entries such as "Balance | Input + Output".
_hvo_arrays = dict(
    input_hvos=dataset.input_grooves.cpu().numpy(),
    output_hvos=dataset.output_streams.cpu().numpy(),
    flat_out_hvos=dataset.flat_output_streams.cpu().numpy(),
)
features_all = FlexControlGroove2TripleStream2BarDataset.extract_features_dict(_hvo_arrays)
del _hvo_arrays  # don't keep the extra array references alive in the kernel

In [None]:
# Per-pattern sample indices where BOTH the input and the flattened output hit
# channels match a straight pattern (computed in the straight-pattern cell).
both_input_output_straight_pattern_indices

In [None]:
from collections import Counter

from data.src.dataLoaders import tokenize_control_feature_array

# NOTE: the variable name is historical (kept for compatibility with other
# cells). It holds the "Intra Stream Exclusiveness" feature, and the plot
# labels below say so instead of "Rhythm Density Synchronization Score".
feature_name = "Intra Stream Exclusiveness"
rhythm_density_sync_scores = features_all[feature_name]

# Tokenization parameters passed to the project helper.
TOKEN_N_BINS = 33
TOKEN_LOW, TOKEN_HIGH = 0, 0.85
HIST_BINS = 13

tokenized = tokenize_control_feature_array(
    control_array=rhythm_density_sync_scores,
    n_bins=TOKEN_N_BINS,
    low=TOKEN_LOW,
    high=TOKEN_HIGH,
)

# Distribution of the continuous feature values.
fig, ax = plt.subplots(figsize=(12, 6))
ax.hist(rhythm_density_sync_scores, bins=HIST_BINS, alpha=0.7, color='blue')
ax.set(xlabel=feature_name, ylabel='Frequency', xlim=(0, 1),
       title=f'Distribution of {feature_name}')
ax.grid(True)
fig.tight_layout()

# Token frequencies. The return type of tokenize_control_feature_array is not
# visible here. Converting to plain Python values makes Counter count by
# value, even if a tensor is returned (tensor elements hash by identity).
counts = Counter(np.asarray(tokenized).ravel().tolist())

fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(list(counts.keys()), list(counts.values()), alpha=0.7, color='blue')
ax.set(xlabel=f'Tokenized {feature_name}', ylabel='Frequency',
       title=f'Tokenized {feature_name} Distribution')
plt.show()


In [None]:
# Balance / Evenness / IOI Entropy (all "Input + Output" features): one scatter
# per pair, colored by the remaining third feature.
balance = features_all["Balance | Input + Output"]
evenness = features_all["Evenness | Input + Output"]
ioi_entropy = features_all["IOI Entropy | Input + Output"]

_scatter_specs = [
    # (x, y, color, x name, y name, color name)
    (balance, evenness, ioi_entropy, 'Balance', 'Evenness', 'IOI Entropy'),
    (balance, ioi_entropy, evenness, 'Balance', 'IOI Entropy', 'Evenness'),
    (evenness, ioi_entropy, balance, 'Evenness', 'IOI Entropy', 'Balance'),
]
for x, y, c, x_name, y_name, c_name in _scatter_specs:
    fig, ax = plt.subplots(figsize=(12, 6))
    points = ax.scatter(x, y, c=c, s=1, alpha=0.1)
    ax.set(xlabel=x_name, ylabel=y_name, title=f'{x_name} vs {y_name} (Colored by {c_name})')
    fig.colorbar(points, ax=ax, label=c_name)
plt.show()


In [None]:
# Split the input grooves and the flattened outputs into their first three
# channels: hits (0), velocities (1), offsets (2).
_input_hvo = dataset.input_grooves.cpu().numpy()
_flat_output_hvo = dataset.flat_output_streams.cpu().numpy()
input_hits, input_vels, input_offsets = (_input_hvo[:, :, ch] for ch in range(3))
output_hits, output_vels, output_offsets = (_flat_output_hvo[:, :, ch] for ch in range(3))

def count_unique_nonzero(row):
    """Return the number of distinct non-zero values in ``row``.

    Zeros are ignored. Applied below to each row of the flattened output
    velocities.
    """
    return np.unique(row[np.nonzero(row)]).size

# Distinct non-zero velocity values per flattened output sample (one count per row).
output_unique_counts = np.apply_along_axis(func1d=count_unique_nonzero, axis=1, arr=output_vels)

In [None]:
from collections import Counter

# How many samples use 0, 1, 2, ... distinct non-zero output velocity values.
Counter(int(n) for n in output_unique_counts)

In [None]:
# Hit totals per step summed over all samples, for the input and the flattened
# output, drawn as side-by-side bars (output shifted right by one bar width).
input_per_step_counts = np.sum(input_hits, axis=0)
output_per_step_counts = np.sum(output_hits, axis=0)

bar_width = 0.4
steps = np.arange(len(input_per_step_counts))

fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(steps, input_per_step_counts, width=bar_width, label='Input', alpha=0.7)
ax.bar(np.arange(len(output_per_step_counts)) + bar_width, output_per_step_counts,
       width=bar_width, label='Output (Flattened)', alpha=0.7)
ax.set_xlabel('Step')
ax.set_ylabel('Number of Samples')
ax.set_title('Number of Active Steps in Input and Output Streams')
ax.set_xticks(steps + bar_width / 2)
ax.set_xticklabels(steps)
ax.legend()
fig.tight_layout()




In [None]:
# Mean velocity per step over ALL samples, input vs flattened output.
# The mean includes samples with no hit at a step. If velocities are 0 there
# (assumed, TODO confirm), this mixes hit density with velocity level.
# Step counts come from the data instead of a hard-coded 32.
in_steps = np.arange(input_vels.shape[1])
out_steps = np.arange(output_vels.shape[1])

fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(in_steps, np.mean(input_vels, axis=0), width=0.4, label='Input', alpha=0.7)
ax.bar(out_steps + 0.4, np.mean(output_vels, axis=0), width=0.4, label='Output (Flattened)', alpha=0.7)
ax.set_xlabel('Step')
ax.set_ylabel('Mean Velocity')
ax.set_title('Mean Velocities in Input and Output Streams')
ax.set_xticks(in_steps + 0.2)
ax.set_xticklabels(in_steps)
ax.legend()
fig.tight_layout()



In [None]:
# Mean offset per step over ALL samples, input vs flattened output.
# Samples without a hit at a step are included in the mean.
# Step counts come from the data instead of a hard-coded 32.
in_steps = np.arange(input_offsets.shape[1])
out_steps = np.arange(output_offsets.shape[1])

fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(in_steps, np.mean(input_offsets, axis=0), width=0.4, label='Input', alpha=0.7)
ax.bar(out_steps + 0.4, np.mean(output_offsets, axis=0), width=0.4, label='Output (Flattened)', alpha=0.7)
ax.set_xlabel('Step')
ax.set_ylabel('Mean Offset')
ax.set_title('Mean Offsets in Input and Output Streams')
ax.set_xticks(in_steps + 0.2)
ax.set_xticklabels(in_steps)
ax.legend()
fig.tight_layout()


In [None]:
# Syncopation and complexity features for the input and the output.
# NOTE(review): despite its name, `input_hit_sync` holds the input's *velocity*
# syncopation ("Syncopation | Input | Velocity"). The name is kept because
# later cells use it.
input_hit_sync, input_complexity, output_Velocity_sync, output_complexity = (
    features_all[key]
    for key in (
        "Syncopation | Input | Velocity",
        "Complexity | Input",
        "Syncopation | Output | Velocity",
        "Complexity | Output",
    )
)

In [None]:
# Scatter of input vs output velocity syncopation.
# `input_hit_sync` holds "Syncopation | Input | Velocity".
fig, ax = plt.subplots(figsize=(12, 6))
ax.scatter(input_hit_sync, output_Velocity_sync, s=1, alpha=0.1)
ax.set(xlabel='Syncopation | Input | Velocity',
       ylabel='Syncopation | Output | Velocity',
       title='Input vs Output Velocity Syncopation')
plt.show()

In [None]:
# Per-stream hit syncopation vs complexity, one figure per output stream.
# `output_hit_sync` is not defined in any cell of this notebook, so plotting it
# raised NameError. Each stream is plotted against its own hit syncopation,
# loaded here and otherwise unused.
stream1_hit_sync = features_all["Syncopation | Stream 1 | Hit"]
stream2_hit_sync = features_all["Syncopation | Stream 2 | Hit"]
stream3_hit_sync = features_all["Syncopation | Stream 3 | Hit"]
stream1_complexity = features_all["Complexity | Stream 1"]
stream2_complexity = features_all["Complexity | Stream 2"]
stream3_complexity = features_all["Complexity | Stream 3"]

_stream_pairs = [
    (stream1_hit_sync, stream1_complexity),
    (stream2_hit_sync, stream2_complexity),
    (stream3_hit_sync, stream3_complexity),
]
for stream_id, (hit_sync, complexity) in enumerate(_stream_pairs, start=1):
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.scatter(hit_sync, complexity, s=1, alpha=0.05)
    ax.set(xlabel=f'Syncopation | Stream {stream_id} | Hit',
           ylabel=f'Complexity | Stream {stream_id}',
           title=f'Stream {stream_id}: Hit Syncopation vs Complexity')
plt.show()