In [1]:
import os

# Create output directory
output_dir = 'outputs'
os.makedirs(output_dir, exist_ok=True)
song_path = 'Bon_Iver_St._Vincent_-_Roslyn_Lyrics.mp3'

song_name = os.path.splitext(os.path.basename(song_path))[0].replace('_', '')
os.makedirs(f'{output_dir}/{song_name}', exist_ok=True)


In [2]:
import torchaudio as ta
print("Available:", ta.list_audio_backends())

Available: ['soundfile']


In [3]:
import torchaudio
import torch
from demucs.pretrained import get_model
from demucs.apply import apply_model

import subprocess
result = subprocess.run(['which', 'ffmpeg'], capture_output=True, text=True)
print(f"FFmpeg location: {result.stdout.strip()}")

# Load the audio
print(f"Loading audio from {song_path}...")
sample_waveform, sample_rate = torchaudio.load(song_path)
print(f"Loaded audio with shape {sample_waveform.shape} and sample rate {sample_rate}")

# Determine device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def prepare_audio(waveform, source_sr, target_sr):
    """Prepare audio for model input"""
    # Resample if needed
    if source_sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, source_sr, target_sr)
        
    # Handle channels
    if waveform.shape[0] > 2:
        waveform = waveform[:2, :]
    elif waveform.shape[0] == 1:
        waveform = torch.cat([waveform, waveform], dim=0)
        
    return waveform

def enhance_guitar(guitar_waveform, sample_rate):
    """Enhance guitar with filters and transient processing"""
    # Add batch dimension if needed
    if guitar_waveform.dim() == 2:
        guitar_waveform = guitar_waveform.unsqueeze(0)
        
    # Now we can safely unpack dimensions
    b, c, t = guitar_waveform.shape
    
    # FFT for frequency domain processing
    guitar_waveform_freq = torch.fft.rfft(guitar_waveform, dim=2)
    
    # Create high-pass filter (reduce below 80Hz)
    freqs = torch.fft.rfftfreq(t, d=1/sample_rate)
    high_pass = (1 - torch.exp(-freqs/80))
    
    # Create mid boost around 2-4kHz (presence)
    mid_boost = 1.0 + 0.5 * torch.exp(-((freqs - 3000)/500)**2)
    
    # Apply filters
    filter_curve = high_pass.view(1, 1, -1) * mid_boost.view(1, 1, -1)
    guitar_waveform_freq *= filter_curve
    
    # Back to time domain
    guitar_waveform = torch.fft.irfft(guitar_waveform_freq, n=t, dim=2)
    
    # Apply subtle compression
    peak = guitar_waveform.abs().max()
    if peak > 0:
        # Simple soft knee compression
        threshold = 0.7
        ratio = 3.0
        gain = 1.2
        
        above_thresh = (guitar_waveform.abs() > threshold * peak).float()
        comp_factor = 1.0 - above_thresh * (1.0 - 1.0/ratio) * (guitar_waveform.abs() - threshold * peak) / (peak * (1.0 - threshold))
        guitar_waveform = guitar_waveform * comp_factor * gain
        
        # Final limiter
        peak = guitar_waveform.abs().max()
        if peak > 0.95:
            guitar_waveform = 0.95 * guitar_waveform / peak
    
    # Remove batch dimension if we added it
    if b == 1:
        guitar_waveform = guitar_waveform.squeeze(0)
        
    return guitar_waveform

# STAGE 1: Extract all stems with htdemucs_ft
print("STAGE 1: Separating with htdemucs_ft...")
model_stage1 = get_model("htdemucs_ft")
model_stage1.eval()
model_stage1.to(device)

# Prepare audio for first model
waveform_stage1 = prepare_audio(sample_waveform, sample_rate, model_stage1.samplerate)
waveform_stage1 = waveform_stage1.to(device)

# Separate first stage
with torch.no_grad():
    sources_stage1 = apply_model(model_stage1, waveform_stage1.unsqueeze(0))[0]
    sources_stage1 = sources_stage1.cpu()

# Get the "other" stem
other_index = model_stage1.sources.index('other') if 'other' in model_stage1.sources else None
if other_index is None:
    print("Warning: 'other' source not found in model 1. Using all non-guitar sources combined.")
    # Combine all sources except guitar to create "other"
    if 'guitar' in model_stage1.sources:
        guitar_index = model_stage1.sources.index('guitar')
        all_sources = torch.zeros_like(sources_stage1[0])
        for i, src in enumerate(model_stage1.sources):
            if i != guitar_index:
                all_sources += sources_stage1[i]
        other_waveform = all_sources
    else:
        # If no guitar source, just use the first stem as "other"
        other_waveform = sources_stage1[0]
else:
    other_waveform = sources_stage1[other_index]

# Save the other stem
other_file = os.path.join(output_dir, song_name, "stage1_other.wav")
torchaudio.save(other_file, other_waveform, model_stage1.samplerate)
print(f"Saved 'other' stem to {other_file}")

# STAGE 2: Extract guitar from "other" stem using htdemucs_6s
print("STAGE 2: Extracting guitar from 'other' using htdemucs_6s...")
model_stage2 = get_model("htdemucs_6s")
model_stage2.eval()
model_stage2.to(device)

# Prepare the "other" stem for second model
other_waveform = prepare_audio(other_waveform, model_stage1.samplerate, model_stage2.samplerate)
other_waveform = other_waveform.to(device)

# Separate second stage
with torch.no_grad():
    sources_stage2 = apply_model(model_stage2, other_waveform.unsqueeze(0))[0]
    sources_stage2 = sources_stage2.cpu()

# Get the guitar from second separation
if 'guitar' in model_stage2.sources:
    guitar_index = model_stage2.sources.index('guitar')
    extracted_guitar = sources_stage2[guitar_index]
    
    # Save the extracted guitar
    guitar_file = os.path.join(output_dir, song_name, "stage2_guitar_from_other.wav")
    torchaudio.save(guitar_file, extracted_guitar, model_stage2.samplerate)
    print(f"Saved extracted guitar to {guitar_file}")
    
    # Enhance and save
    enhanced_guitar = enhance_guitar(extracted_guitar, model_stage2.samplerate)
    enhanced_file = os.path.join(output_dir, song_name, "stage2_guitar_enhanced.wav")
    torchaudio.save(enhanced_file, enhanced_guitar, model_stage2.samplerate)
    print(f"Saved enhanced guitar to {enhanced_file}")
else:
    print("Error: 'guitar' source not found in the second model")

# BONUS: Also get the guitar from the first separation for comparison
if 'guitar' in model_stage1.sources:
    guitar_index = model_stage1.sources.index('guitar')
    original_guitar = sources_stage1[guitar_index]
    
    # Save the original guitar stem
    orig_guitar_file = os.path.join(output_dir, song_name, "stage1_original_guitar.wav")
    torchaudio.save(orig_guitar_file, original_guitar, model_stage1.samplerate)
    print(f"Saved original guitar stem to {orig_guitar_file}")
    
    # Create an enhanced version of the original guitar
    enhanced_orig_guitar = enhance_guitar(original_guitar, model_stage1.samplerate)
    enhanced_orig_file = os.path.join(output_dir, song_name, "stage1_original_guitar_enhanced.wav")
    torchaudio.save(enhanced_orig_file, enhanced_orig_guitar, model_stage1.samplerate)
    print(f"Saved enhanced original guitar to {enhanced_orig_file}")
    
    # FINAL STEP: Try combining both guitar extractions for maximum clarity
    # Resample if needed to match sample rates
    if model_stage1.samplerate != model_stage2.samplerate:
        original_guitar = torchaudio.functional.resample(
            original_guitar, model_stage1.samplerate, model_stage2.samplerate)
    
    # Make sure shapes match
    min_length = min(original_guitar.shape[1], extracted_guitar.shape[1])
    original_guitar = original_guitar[:, :min_length]
    extracted_guitar = extracted_guitar[:, :min_length]
    
    # Blend with 70% from first model, 30% from second model
    combined_guitar = 0.7 * original_guitar + 0.3 * extracted_guitar
    
    # Enhance the combined result
    enhanced_combined = enhance_guitar(combined_guitar, model_stage2.samplerate)
    combined_file = os.path.join(output_dir, song_name, "combined_guitar_enhanced.wav")
    torchaudio.save(combined_file, enhanced_combined, model_stage2.samplerate)
    print(f"Saved combined enhanced guitar to {combined_file}")

print("Processing complete!")

FFmpeg location: /opt/homebrew/bin/ffmpeg
Loading audio from Bon_Iver_St._Vincent_-_Roslyn_Lyrics.mp3...
Loaded audio with shape torch.Size([2, 14800214]) and sample rate 48000
Using device: mps
STAGE 1: Separating with htdemucs_ft...
Saved 'other' stem to outputs/BonIverSt.Vincent-RoslynLyrics/stage1_other.wav
STAGE 2: Extracting guitar from 'other' using htdemucs_6s...
Saved extracted guitar to outputs/BonIverSt.Vincent-RoslynLyrics/stage2_guitar_from_other.wav
Saved enhanced guitar to outputs/BonIverSt.Vincent-RoslynLyrics/stage2_guitar_enhanced.wav
Processing complete!


## Analysis

Audio File → Audio Analysis Model → Extract tempo/rhythm → 
Text Description → LLM → Strumming Suggestions

In [4]:
import torchaudio

def trim_audio(input_file, output_file=None, start_sec=0, end_sec=None):
    """
    Trim audio file to specified start and end times.
    
    Parameters:
    - input_file: Path to the input audio file
    - output_file: Path to save the trimmed file (if None, returns without saving)
    - start_sec: Start time in seconds
    - end_sec: End time in seconds (if None, trims to the end of the file)
    
    Returns:
    - trimmed_waveform: Tensor containing the trimmed audio
    - sample_rate: Sample rate of the audio
    """
    # Load the audio
    waveform, sample_rate = torchaudio.load(input_file)
    
    # Convert time to samples
    start_sample = int(start_sec * sample_rate)
    end_sample = int(end_sec * sample_rate) if end_sec is not None else waveform.shape[1]
    
    # Trim the audio
    trimmed_waveform = waveform[:, start_sample:end_sample]
    
    # Save the trimmed audio if output_file is provided
    if output_file:
        torchaudio.save(output_file, trimmed_waveform, sample_rate)
    
    return trimmed_waveform, sample_rate

trim_audio(f'outputs/{song_name}/stage2_guitar_enhanced.wav', start_sec = 10, end_sec = 100, output_file=f'outputs/{song_name}/stage2_guitar_enhanced_cut.wav')

(tensor([[-0.1056, -0.1083, -0.1123,  ...,  0.0472,  0.0460,  0.0447],
         [ 0.0179,  0.0209,  0.0228,  ...,  0.0922,  0.0914,  0.0905]]),
 44100)

In [4]:
"""dynamic_guitar_strum_analysis.py  –  chords & notes processed **independently**
==========================================================================

* Deep‑Chroma (madmom) → **chord timeline** (segment‑level)
* torchcrepe (or pyin fallback) → **note timeline** (event‑level)
* Original beat‑aligned strum/chord/bar/section logic LEFT INTACT so your UI
  keeps working, but we **do not overwrite chords with notes** anymore.
* Extra DataFrames returned: `chords_timeline`, `notes_timeline`.
* One label per DataFrame → no clobbering; overlapping times are fine.

Install (CPU only):
    pip install librosa madmom torch torchaudio torchcrepe pandas tqdm "numpy<1.24"

If you later add a CUDA PyTorch wheel, torchcrepe will use it automatically.
"""
from __future__ import annotations
import warnings, traceback
from dataclasses import dataclass, asdict
from typing import List, Dict

import numpy as np, librosa, pandas as pd, torch, torchaudio
from pathlib import Path
from tqdm import tqdm

# --------------------------------------------------------------------------
# Dataclasses (unchanged for strums/bars/sections)
# --------------------------------------------------------------------------
@dataclass
class Strum:
    time: float; bar: int; sub_16: int; direction: str; velocity: float; kind: str; label: str   # kind NOTE|CHORD|NONE

@dataclass
class BarSummary:
    bar: int; bit_pattern: str; down_up: str; mean_vel: float; chords: List[str]

@dataclass
class Section:
    start_bar: int; end_bar: int; pattern_bits: str; chords: List[str]

# --------------------------------------------------------------------------
# Polyphony helper (v2) -----------------------------------------------------
# --------------------------------------------------------------------------
def _is_polyphonic(mag_db: np.ndarray,
                   peak_db: float = -35.0,
                   min_peaks: int = 3,
                   dom_margin: float = 8.0) -> bool:
    """
    Return True if the frame is almost certainly a chord.
    Override: if the strongest peak is `dom_margin` dB louder than the
    2nd‑strongest, treat as monophonic even when `min_peaks` is exceeded.
    """
    strong = mag_db > peak_db
    # indices of strong bins
    idx = np.flatnonzero(strong)
    if len(idx) == 0:
        return False

    # dominant‑peak override
    sorted_db = np.sort(mag_db[idx])
    if len(sorted_db) >= 2 and sorted_db[-1] - sorted_db[-2] >= dom_margin:
        return False                            # clearly one string dominates

    # otherwise count distinct strong groups
    groups = np.split(strong, np.flatnonzero(~strong) + 1)
    n_peaks = sum(g.any() for g in groups)
    return n_peaks >= min_peaks


def spectral_centroid_direction(y: np.ndarray, sr: int, onset_frames: np.ndarray) -> List[str]:
    """Classify Down/Up strokes by sign of spectral‑centroid slope around attack."""
    cent = librosa.feature.spectral_centroid(y=y, sr=sr, hop_length=512)[0]
    dirs = []
    for f in onset_frames:
        a = max(0, f-2)
        b = min(len(cent)-1, f+2)
        dirs.append('D' if cent[b] - cent[a] < 0 else 'U')
    return dirs

# --------------------------------------------------------------------------
# 1.  Deep‑Chroma chord timeline (segment‑level) ----------------------------
# --------------------------------------------------------------------------

def chord_timeline(audio_path: str) -> pd.DataFrame:
    """Return DF with columns [start, end, chord]."""
    from madmom.audio.chroma import DeepChromaProcessor
    from madmom.features.chords import DeepChromaChordRecognitionProcessor
    chroma = DeepChromaProcessor()(audio_path)
    segs   = DeepChromaChordRecognitionProcessor()(chroma)
    df = pd.DataFrame(segs, columns=["start", "end", "label"])
    df["label"] = df["label"].str.split("/").str[0]
    return df

# --------------------------------------------------------------------------
# 2.  torchcrepe / pyin note timeline (event‑level) -------------------------
# --------------------------------------------------------------------------

# --------------------------------------------------------------------------
# 2.  Gated note timeline (event‑level, single‑string only) -----------------
# --------------------------------------------------------------------------
def note_timeline(audio_path: str,
                  hop_s: float = 0.01,
                  conf_thresh: float = .8,
                  peak_db: float = -52,
                  min_peaks: int = 4,
                  device: str | None = None) -> pd.DataFrame:
    """
    Returns DF [time, note] containing ONLY intentionally plucked single‑string notes.
    Strategy:
      1. Find onsets (same settings as the rest of the pipeline).
      2. For each onset grab a 40 ms slice and CQT → run _is_polyphonic().
      3. Only if that slice is *not* polyphonic do we call torchcrepe/pyin.
    """
    import torchcrepe, torchcrepe.decode as tcd, torchcrepe.filter as tcf

    # --- load + onset detection ------------------------------------------
    y, sr = librosa.load(audio_path, sr=None, mono=True)
    onset_env   = librosa.onset.onset_strength(y=y, sr=sr)
    onset_frames = librosa.onset.onset_detect(onset_envelope=onset_env, sr=sr)
    onset_times  = librosa.frames_to_time(onset_frames, sr=sr)

    # --- constants --------------------------------------------------------
    slice_ms = 40                                   # analysis window
    slice_samps = int(sr * slice_ms / 1000.0)
    hop_len = int(round(16000 * hop_s))

    # --- prepare harmonic layer & CQT for gate ---------------------------
    y_harm, _ = librosa.effects.hpss(y)             # helps both gates
    C = np.abs(librosa.cqt(y_harm,
                           sr=sr,
                           hop_length=512,
                           n_bins=84,
                           bins_per_octave=12))
    C_db = librosa.amplitude_to_db(C, ref=np.max)

    # --- choose device for torchcrepe ------------------------------------
    if device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    # --- resample once for torchcrepe ------------------------------------
    y16 = torchaudio.functional.resample(torch.tensor(y_harm),
                                         sr, 16000) if sr != 16000 else torch.tensor(y_harm)
    y16 = y16.unsqueeze(0).to(device)

    rows = []
    try:
        for t, fr in zip(onset_times, onset_frames):
            # ------------------------------------------------------------------
            # 2·A  Polyphony gate  (CQT frame centred on onset)
            # ------------------------------------------------------------------
            # shift 30 ms (≈ 3 CQT hops at hop_length=512) past the onset
            off = fr + 3
            cqt_frame = C_db[:, off] if off < C_db.shape[1] else C_db[:, -1]
            if _is_polyphonic(cqt_frame, peak_db, min_peaks):
                continue                                  # reject → chord

            # ------------------------------------------------------------------
            # 2·B  Periodicity gate (torchcrepe, harmonic layer only)
            # ------------------------------------------------------------------
            start16 = max(0, int(t * 16000) - hop_len//2)
            end16   = start16 + hop_len
            frame = y16[..., start16:end16]               # shape (1, N)
            f0, pdist = torchcrepe.predict(frame,
                                           16000,
                                           hop_len,
                                           model='full',
                                           decoder=tcd.argmax,
                                           fmin=80, fmax=1200,
                                           batch_size=64,
                                           device=device,
                                           return_periodicity=True)
            f0 = tcf.median(f0, 3)
            hz   = float(f0.squeeze())
            pval = float(pdist.squeeze())
            if pval < conf_thresh or not (80 < hz < 1200):
                continue                                  # weak/confused

            rows.append((t, librosa.hz_to_note(hz, octave=False)))

    except Exception as e:
        warnings.warn(f"torchcrepe failed ({e}); falling back to pyin")
        for t, fr in zip(onset_times, onset_frames):
            if _is_polyphonic(C_db[:, fr], peak_db, min_peaks):
                continue
            start = max(0, fr*512)
            end   = start + slice_samps
            f0, _, _ = librosa.pyin(y[start:end], fmin=80, fmax=1200, sr=sr)
            if f0 is not None and not np.isnan(f0).all():
                hz = float(np.nanmedian(f0))
                rows.append((t, librosa.hz_to_note(hz, octave=False)))

    return pd.DataFrame(rows, columns=["time", "note"])

# --------------------------------------------------------------------------
# 3.  Helper: beat‑level chord map (legacy, for strums/bars) ----------------
# --------------------------------------------------------------------------

def chord_sequence_by_beat(chords_df: pd.DataFrame, beat_times: np.ndarray):
    idx, seg_i = {}, 0
    for b, bt in enumerate(beat_times):
        while seg_i+1 < len(chords_df) and chords_df.iloc[seg_i]['end'] <= bt:
            seg_i += 1
        idx[b] = chords_df.iloc[seg_i]['label']
    return idx

# --------------------------------------------------------------------------
# 4.  Core analysis (strums/bars/sections) – chord map only -----------------
# --------------------------------------------------------------------------

def analyse_audio(audio_path: str, return_dataframes: bool = True):
    y, sr = librosa.load(audio_path, sr=None)
    tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr, units='frames', tightness=400)
    tempo = float(np.atleast_1d(tempo)[0])
    beat_times = librosa.frames_to_time(beat_frames, sr=sr)
    # ----- chord & note timelines (independent) -----
    chords_df = chord_timeline(audio_path)
    notes_df  = note_timeline(audio_path)
    beat_chords = chord_sequence_by_beat(chords_df, beat_times)

    # ----- onsets / strums (keep original behaviour) -----
    onset_env = librosa.onset.onset_strength(y=y, sr=sr)
    onset_frames = librosa.onset.onset_detect(onset_envelope=onset_env, sr=sr)
    onset_times  = librosa.frames_to_time(onset_frames, sr=sr)

    grid_step = 60/tempo/4
    grid_times = np.arange(beat_times[0], beat_times[-1]+grid_step, grid_step)

    y_harm, _ = librosa.effects.hpss(y)
    chroma = librosa.feature.chroma_cqt(y=y_harm, sr=sr)
    directions = spectral_centroid_direction(y, sr, onset_frames)
    rms = librosa.feature.rms(y=y, frame_length=2048, hop_length=512)[0]

    strums: List[Strum] = []
    for i, (t, fr) in enumerate(zip(onset_times, onset_frames)):
        gidx = int(np.argmin(np.abs(grid_times - t)))
        bar_idx, sub16 = divmod(gidx, 16)
        vel = float(rms[min(len(rms)-1, fr)])
        kind, label = 'CHORD', beat_chords.get(int(np.argmin(np.abs(beat_times - t))), 'N')
        strums.append(Strum(time=float(t), bar=bar_idx+1, sub_16=sub16,
                            direction=directions[i], velocity=vel,
                            kind=kind, label=label))

    # ----- summarise bars/sections (same as before) -----
    bars: Dict[int, BarSummary] = {}
    for s in strums:
        b = bars.setdefault(s.bar, BarSummary(bar=s.bar, bit_pattern=['0']*16,
                     down_up=['-']*16, mean_vel=0.0, chords=[]))
        b.bit_pattern[s.sub_16] = '1'
        b.down_up[s.sub_16] = s.direction
        b.mean_vel += s.velocity
        if s.kind == 'CHORD':
            b.chords.append(s.label)
    for b in bars.values():
        hits = b.bit_pattern.count('1')
        b.mean_vel /= max(1, hits)
        b.bit_pattern = ''.join(b.bit_pattern)
        b.down_up = ' '.join(b.down_up)
        b.chords = sorted(set(b.chords))

    ordered = [bars[k] for k in sorted(bars)]
    # simple section clustering unchanged for brevity ...

    if return_dataframes:
        return dict(
            tempo_bpm=tempo,
            strums=pd.DataFrame([asdict(s) for s in strums]),
            bars=pd.DataFrame([asdict(b) for b in ordered]),
            chords_timeline=chords_df,
            notes_timeline=notes_df,
        )
    else:
        return dict(
            tempo_bpm=tempo,
            strums=[asdict(s) for s in strums],
            bars=[asdict(b) for b in ordered],
            chords_timeline=chords_df.to_dict('records'),
            notes_timeline=notes_df.to_dict('records'),
        )

# --------------------------------------------------------------------------
# Notebook helper -----------------------------------------------------------

def run_in_notebook(audio_path: str):
    data = analyse_audio(audio_path, return_dataframes=True)
    from IPython.display import display
    print(f"Tempo ≈ {data['tempo_bpm']:.1f} BPM\n")
    print("Chord segments:"); display(data['chords_timeline'].head())
    print("Note events:");   display(data['notes_timeline'].head())
    print("\nStrums (first 10):"); display(data['strums'].head(10))
    print("\nBars:"); display(data['bars'].head())
    return data


In [5]:
audio_path = f'outputs/{song_name}/stage2_guitar_enhanced_cut.wav'
# data = run_in_notebook(audio_path)
data = analyse_audio(audio_path, return_dataframes=True)
return_dataframes = True

  import pkg_resources
  file_sample_rate, signal = wavfile.read(filename, mmap=True)


In [7]:
data['chords_timeline']

Unnamed: 0,start,end,label
0,0.0,0.2,N
1,0.2,4.2,F#:maj
2,4.2,4.9,C#:maj
3,4.9,7.1,A#:maj
4,7.1,9.9,D#:min
5,9.9,11.1,A#:min
6,11.1,15.2,D#:min
7,15.2,17.6,A#:min
8,17.6,20.1,A#:maj
9,20.1,22.3,F#:maj


In [116]:
import json
import pandas as pd
import numpy as np

# Function to convert pandas DataFrames and NumPy arrays to JSON-serializable types
def convert_to_serializable(obj):
    if isinstance(obj, pd.DataFrame):
        return obj.to_dict(orient='records')  # Convert DataFrame to list of dictionaries
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, (np.integer, np.floating)):
        return float(obj) if isinstance(obj, np.floating) else int(obj)
    elif isinstance(obj, dict):
        return {k: convert_to_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list) or isinstance(obj, tuple):
        return [convert_to_serializable(i) for i in obj]
    else:
        return obj

# Convert the data and save to JSON
serializable_data = convert_to_serializable(data)

# Save to file
with open(f'outputs/{song_name}/guitar_data.json', 'w') as f:
    json.dump(serializable_data, f, indent=4)

In [117]:
data.keys()

dict_keys(['tempo_bpm', 'strums', 'bars', 'chords_timeline', 'notes_timeline'])

In [30]:
import numpy as np
import pandas as pd
import torch
import librosa
from typing import List, Dict, Optional
from dataclasses import dataclass
from collections import defaultdict

# ────────────────────────────────────────────────────────────────
# Helper utilities
# ────────────────────────────────────────────────────────────────

def to_numpy(x):
    """Return a NumPy array no matter where the tensor currently lives."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def ensure_cpu(x):
    """Move a tensor to CPU if it is on an accelerator."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu()
    return x

# ────────────────────────────────────────────────────────────────
# Data classes
# ────────────────────────────────────────────────────────────────

@dataclass
class GuitarNote:
    time: float
    pitch: float
    string: int
    fret: int
    duration: float
    confidence: float


@dataclass
class GuitarChord:
    time: float
    duration: float
    root: str
    quality: str
    notes: List[GuitarNote]
    fret_positions: List[int]
    chord_name: str

# ────────────────────────────────────────────────────────────────
# Core processing
# ────────────────────────────────────────────────────────────────

def process_fretnet_output_126dim(output: Dict,
                                 confidence_threshold: float = 0.5,
                                 hop_length: int = 512,
                                 sr: int = 22050,
                                 n_frets: int = 21) -> List[GuitarNote]:
    """Convert FretNet raw output (126‑dim tablature) into a list of GuitarNote objects."""

    tablature = output["tablature"]
    onsets = output.get("onsets", None)

    # convert tensors to NumPy – always!
    tablature = to_numpy(tablature)
    if onsets is not None:
        onsets = to_numpy(onsets)

    # Remove batch dim if present -> (frames, 126)
    if tablature.ndim == 3:
        tablature = tablature[0]

    n_frames = tablature.shape[0]
    n_strings = 6

    # reshape to (frames, strings, frets)
    if tablature.shape[1] == n_strings * n_frets:
        tab_3d = tablature.reshape(n_frames, n_strings, n_frets)
        tab_2d = None
    elif tablature.shape[1] == n_strings:  # already condensed (frames × strings)
        tab_2d = tablature
        tab_3d = None
    else:
        print("[process_fretnet_output_126dim] Unexpected tablature shape:", tablature.shape)
        return []

    open_strings = [40, 45, 50, 55, 59, 64]  # MIDI numbers of E2, A2, D3, G3, B3, E4
    notes: List[GuitarNote] = []

    for string_idx in range(n_strings):
        if tab_3d is not None:
            string_activations = tab_3d[:, string_idx, :]
            fret_indices = np.argmax(string_activations, axis=1)
            fret_confidences = np.max(string_activations, axis=1)

            open_string_conf = string_activations[:, 0]
            silence_frames = (
                (fret_confidences < confidence_threshold)
                | ((fret_indices == 0) & (open_string_conf < confidence_threshold))
            )
            active_frames = (~silence_frames)
        else:
            string_frets = tab_2d[:, string_idx]
            active_frames = string_frets >= 0
            fret_indices = string_frets.astype(int)
            fret_confidences = np.full_like(string_frets, 0.8)

        if not active_frames.any():
            continue

        # onset signal may or may not be present – handle flexible dims
        onset_signal = None
        if onsets is not None:
            if onsets.ndim == 2 and string_idx < onsets.shape[1]:
                onset_signal = onsets[:, string_idx]
            elif onsets.ndim == 3 and string_idx < onsets.shape[2]:
                onset_signal = onsets[0, :, string_idx]

        # segment active regions
        changes = np.diff(np.concatenate(([False], active_frames, [False])))
        note_starts = np.where(changes)[0][::2]
        note_ends = np.where(changes)[0][1::2]

        for start, end in zip(note_starts, note_ends):
            if end - start < 2:  # ignore too short
                continue
            segment_frets = fret_indices[start:end]
            segment_conf = fret_confidences[start:end]

            values, counts = np.unique(segment_frets, return_counts=True)
            fret = int(values[np.argmax(counts)])
            if fret < 0 or fret > n_frets:
                continue

            start_time = (start * hop_length) / sr
            duration = ((end - start) * hop_length) / sr
            if duration < 0.05:
                continue

            midi_pitch = open_strings[string_idx] + fret
            confidence = float(np.mean(segment_conf))

            notes.append(GuitarNote(
                time=start_time,
                pitch=midi_pitch,
                string=string_idx,
                fret=fret,
                duration=duration,
                confidence=confidence,
            ))

    return sorted(notes, key=lambda n: n.time)

# ────────────────────────────────────────────────────────────────
# Chord grouping helpers (original logic kept unchanged)
# ────────────────────────────────────────────────────────────────

def identify_chord_type(pitch_classes: List[int]):
    note_names = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
    unique_pcs = sorted(set(pitch_classes))

    chord_patterns = {
        (0, 4, 7): ("maj", "Major"),
        (0, 3, 7): ("min", "Minor"),
        (0, 4, 7, 11): ("maj7", "Major 7"),
        (0, 3, 7, 10): ("min7", "Minor 7"),
        (0, 4, 7, 10): ("7", "Dominant 7"),
        (0, 3, 6): ("dim", "Diminished"),
        (0, 4, 8): ("aug", "Augmented"),
        (0, 5, 7): ("sus4", "Suspended 4"),
        (0, 2, 7): ("sus2", "Suspended 2"),
    }

    for root_pc in unique_pcs:
        intervals = tuple(sorted(((pc - root_pc) % 12) for pc in unique_pcs))
        if intervals in chord_patterns:
            quality, full_name = chord_patterns[intervals]
            root_name = note_names[root_pc]
            return root_name, quality, f"{root_name} {full_name}"

    if unique_pcs:
        root_pc = unique_pcs[0]
        return note_names[root_pc], "unknown", f"{note_names[root_pc]} (unknown)"

    return None, None, None


def analyze_chord(notes: List[GuitarNote]) -> Optional[GuitarChord]:
    pitch_classes = [n.pitch % 12 for n in notes]
    fret_positions = [-1] * 6
    for n in notes:
        fret_positions[n.string] = n.fret

    root, quality, chord_name = identify_chord_type(pitch_classes)
    if root is None:
        return None

    start_time = min(n.time for n in notes)
    end_time = max(n.time + n.duration for n in notes)

    return GuitarChord(
        time=start_time,
        duration=end_time - start_time,
        root=root,
        quality=quality,
        notes=notes,
        fret_positions=fret_positions,
        chord_name=chord_name,
    )


def group_notes_to_chords(notes: List[GuitarNote], time_threshold: float = 0.05, min_notes: int = 2):
    time_groups: Dict[float, List[GuitarNote]] = defaultdict(list)
    for n in notes:
        key = round(n.time / time_threshold) * time_threshold
        time_groups[key].append(n)

    chords: List[GuitarChord] = []
    for t, grp in time_groups.items():
        if len(grp) < min_notes:
            continue
        grp.sort(key=lambda n: n.string)
        chord = analyze_chord(grp)
        if chord:
            chords.append(chord)
    return sorted(chords, key=lambda c: c.time)

# ────────────────────────────────────────────────────────────────
# Tablature helper
# ────────────────────────────────────────────────────────────────

def create_tablature_from_notes(notes: List[GuitarNote], max_time: float = 30.0):
    if not notes:
        return "No notes detected"

    time_slots = defaultdict(lambda: [-1] * 6)
    for n in notes:
        if n.time > max_time:
            break
        slot = int(n.time * 4) / 4  # quarter‑second grid
        time_slots[slot][n.string] = n.fret

    lines = ["e |", "B |", "G |", "D |", "A |", "E |"]
    for t in sorted(time_slots.keys())[:20]:
        frets = time_slots[t]
        for i in range(6):  # display high → low
            fret = frets[5 - i]
            if fret == -1:
                lines[i] += "---"
            elif fret == 0:
                lines[i] += "-0-"
            elif fret < 10:
                lines[i] += f"-{fret}-"
            else:
                lines[i] += f"{fret}-"
    return "\n".join(lines)

# ────────────────────────────────────────────────────────────────
# Main entry point
# ────────────────────────────────────────────────────────────────

def fretnet_transcribe_updated(
    audio_path: str,
    model_path: str,
    confidence_threshold: float = 0.5,
    chord_time_threshold: float = 0.05,
    device: str = "mps",
):
    """High‑level wrapper: load model, run inference, return structured results."""

    from amt_tools.features import HVQT
    from guitar_transcription_continuous.models import FretNet

    # choose device gracefully
    if device == "mps" and not torch.backends.mps.is_available():
        device = "cpu"
    dev = torch.device(device)

    # ─── load model checkpoint ───
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    if isinstance(ckpt, FretNet):
        model = ckpt
    else:
        model = FretNet()
        state_dict = ckpt.get("model_state_dict", ckpt)
        model.load_state_dict(state_dict, strict=False)

    model.eval().to(dev)

    # ─── feature extraction ───
    y, sr = librosa.load(audio_path, sr=22050)
    hvqt = HVQT(
        sample_rate=sr,
        hop_length=512,
        n_bins=144,
        bins_per_octave=36,
        fmin=librosa.note_to_hz("E2"),
    )
    feats = hvqt.process_audio(y)  # (6, 144, T)
    print("HVQT raw:", feats.shape)
    
    # pad frame axis (axis 2) so T % 9 == 0
    frame_axis = 2
    pad = (-feats.shape[frame_axis]) % 9
    if pad:
        feats = np.pad(feats, ((0, 0), (0, 0), (0, pad)), mode="constant")

    feats_tensor = torch.tensor(feats, dtype=torch.float32).unsqueeze(0).to(dev)  # (1, 6, 144, T′)

    assert feats_tensor.shape[1] == 6
    assert feats_tensor.shape[2] == 144
    assert feats_tensor.shape[3] % 9 == 0, "frame axis must be multiple of 9"

    global output

    # ─── inference ───
    with torch.no_grad():
        output = model(feats_tensor)

    # move tensors in output back to CPU
    
    output = {k: ensure_cpu(v) for k, v in output.items()}

    # ─── post‑processing ───
    notes = process_fretnet_output_126dim(output, confidence_threshold)
    chords = group_notes_to_chords(notes, chord_time_threshold)

    notes_df = pd.DataFrame(
        [
            {
                "time": n.time,
                "pitch": n.pitch,
                "note_name": librosa.midi_to_note(n.pitch),
                "string": n.string,
                "fret": n.fret,
                "duration": n.duration,
                "confidence": n.confidence,
            }
            for n in notes
        ]
    )

    chords_df = pd.DataFrame(
        [
            {
                "time": c.time,
                "duration": c.duration,
                "chord_name": c.chord_name,
                "root": c.root,
                "quality": c.quality,
                "fret_positions": c.fret_positions,
                "num_notes": len(c.notes),
            }
            for c in chords
        ]
    )

    tablature = create_tablature_from_notes(notes)

    return {
        "notes": notes,
        "chords": chords,
        "notes_df": notes_df,
        "chords_df": chords_df,
        "tablature": tablature,
        "raw_output": output,
    }

# ────────────────────────────────────────────────────────────────
# Example usage (comment out if importing as a module)
# ────────────────────────────────────────────────────────────────


audio_path = '/Users/danielcrake/Desktop/Guitar-Separator/testing/outputs/FontainesD.C.-BugOfficialVideo/stage2_guitar_enhanced.wav'
model_path = "/Users/danielcrake/Desktop/Guitar-Separator/FretNet/models/fold-0/model-2000.pt"

results = fretnet_transcribe_updated(audio_path, model_path)

print("\nResults:")
if not results['notes_df'].empty:
    print(f"Found {len(results['notes_df'])} notes")
    print("\nFirst few notes:")
    print(results['notes_df'].head())

if not results['chords_df'].empty:
    print(f"\nFound {len(results['chords_df'])} chords")
    print("\nFirst few chords:")
    print(results['chords_df'].head())

print("\nTablature:")
print(results['tablature'])

HVQT raw: (6, 144, 8843)

Results:
Found 447 notes

First few notes:
   time  pitch note_name  string  fret  duration  confidence
0   0.0     40        E2       0     0  0.743039    9.697382
1   0.0     45        A2       1     0  0.441179    8.104714
2   0.0     50        D3       2     0  0.835918    4.915620
3   0.0     55        G3       3     0  0.650159    3.640748
4   0.0     59        B3       4     0  0.789478    5.761517

Found 121 chords

First few chords:
       time  duration   chord_name root  quality          fret_positions  \
0  0.000000  0.835918  D (unknown)    D  unknown      [0, 0, 0, 0, 0, 0]   
1  0.626939  0.208980  E (unknown)    E  unknown   [-1, 0, -1, 0, -1, 0]   
2  0.835918  0.162540  D (unknown)    D  unknown   [0, -1, 0, -1, -1, 0]   
3  0.975238  0.185760      E Minor    E      min   [0, -1, -1, 0, 0, -1]   
4  2.577415  0.116100  E (unknown)    E  unknown  [0, -1, -1, -1, -1, 0]   

   num_notes  
0          6  
1          3  
2          3  
3          

In [31]:
output

{'tablature': tensor([[[  6.9916, -10.3610, -10.8343,  ..., -16.0457, -15.7000, -16.1509],
          [ 12.9874, -13.7650, -17.0816,  ..., -12.9151, -13.8413, -13.1564],
          [  3.3934,  -7.1001,  -8.7826,  ..., -19.3541, -19.5007, -20.1660],
          ...,
          [ 14.6498, -15.1403, -16.6030,  ..., -20.1774, -20.3374, -20.9411],
          [ 12.4819, -14.7568, -16.2506,  ..., -15.3762, -15.7362, -15.1316],
          [ 10.2246, -14.2621, -18.2183,  ..., -15.4078, -15.1092, -14.8960]]]),
 'tablature_rel': tensor([[[ 1.3033e-03, -9.0692e-05, -9.7414e-04,  ...,  1.9838e-06,
           -5.2828e-06,  2.7949e-05],
          [ 1.3033e-03, -9.0692e-05, -9.7414e-04,  ...,  1.9838e-06,
           -5.2828e-06,  2.7949e-05],
          [ 1.3033e-03, -9.0692e-05, -9.7414e-04,  ...,  1.9838e-06,
           -5.2828e-06,  2.7949e-05],
          ...,
          [ 1.3033e-03, -9.0692e-05, -9.7414e-04,  ...,  1.9838e-06,
           -5.2828e-06,  2.7949e-05],
          [ 1.3033e-03, -9.0692e-05, -9.7

In [32]:
import amt_tools.tools as tools

tools.KEY_NOTES


'notes'

In [1]:
from guitar_transcription_continuous.estimators import StackedPitchListTablatureWrapper
from amt_tools.features import HCQT

from amt_tools.transcribe import ComboEstimator, \
                                 TablatureWrapper, \
                                 StackedOffsetsWrapper, \
                                 StackedNoteTranscriber
from amt_tools.inference import run_offline

import guitar_transcription_continuous.utils as utils
import amt_tools.tools as tools

# Regular imports
import matplotlib.pyplot as plt
import matplotlib
import librosa
import torch
import os


matplotlib.use('TkAgg')

# Define path to model and audio to transcribe
model_path = '/Users/danielcrake/Desktop/Guitar-Separator/FretNet/models/fold-0/model-2000.pt'
audio_path = '/Users/danielcrake/Desktop/Guitar-Separator/testing/outputs/JeffBuckley-LoverYouShouldveComeOverAudio/stage2_guitar_enhanced_cut.wav'

# Number of samples per second of audio
sample_rate = 22050
# Number of samples between frames
hop_length = 512
# Flag to re-acquire ground-truth data and re-calculate features
reset_data = False
# Choose the GPU on which to perform evaluation
gpu_id = 0


# device = torch.device(f'device="mps" if torch.backends.mps.is_available() else "cpu"')

# Load the model
# Initialize a device pointer for loading the model
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Load the model
model = torch.load(model_path, map_location=device, weights_only=False)
model = model.to(device)
model.eval()

# CRITICAL FIX: Update internal device properties
model.device = device
if hasattr(model, 'tablature_layer'):
    model.tablature_layer.device = device

# Fix all submodules with device properties
for name, module in model.named_modules():
    if hasattr(module, 'device'):
        module.device = device

print(f"Model device after fix: {model.device}")


##############################
# Predictions                #
##############################

# Load in the audio and normalize it
audio, _ = tools.load_normalize_audio(audio_path, sample_rate)

# Create an HCQT feature extraction module comprising
# the first five harmonics and a sub-harmonic, where each
# harmonic transform spans 4 octaves w/ 3 bins per semitone
data_proc = HCQT(sample_rate=sample_rate,
                 hop_length=hop_length,
                 fmin=librosa.note_to_hz('E2'),
                 harmonics=[0.5, 1, 2, 3, 4, 5],
                 n_bins=144, bins_per_octave=36)

# Compute the features
features = {tools.KEY_FEATS : data_proc.process_audio(audio),
            tools.KEY_TIMES : data_proc.get_times(audio)}

# Initialize the estimation pipeline
estimator = ComboEstimator([
    # Discrete tablature -> stacked multi pitch array
    TablatureWrapper(profile=model.profile),
    # Stacked multi pitch array -> stacked offsets array
    StackedOffsetsWrapper(profile=model.profile),
    # Stacked multi pitch array -> stacked notes
    StackedNoteTranscriber(profile=model.profile),
    # Continuous tablature arrays -> stacked pitch list
    StackedPitchListTablatureWrapper(profile=model.profile,
                                     multi_pitch_key=tools.KEY_TABLATURE,
                                     multi_pitch_rel_key=utils.KEY_TABLATURE_REL)])

# Perform inference offline
predictions = run_offline(features, model, estimator)

# Extract the estimated notes
stacked_notes_est = predictions[tools.KEY_NOTES]

##############################
# Plotting                   #
##############################

# Convert the estimated notes to frets
stacked_frets_est = tools.stacked_notes_to_frets(stacked_notes_est)

# Plot estimated tablature and add an appropriate title
fig_est = tools.initialize_figure(interactive=False, figsize=(20, 5))
fig_est = tools.plot_guitar_tablature(stacked_frets_est, fig=fig_est)
fig_est.suptitle('Inference')

# Display the plot
plt.show(block=True)

  from pkg_resources import resource_filename


Using device: mps
Model device after fix: mps


In [2]:
import numpy as np
from collections import defaultdict, Counter
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
import itertools

@dataclass
class ChordResult:
    time_start: float
    time_end: float
    chord_name: str
    root_note: str
    chord_quality: str
    notes: List[str]
    confidence: float
    bass_note: Optional[str] = None

class AdvancedChordRecognizer:
    def __init__(self):
        self.setup_chord_database()
        self.setup_guitar_tuning()
        
    def setup_guitar_tuning(self):
        """Standard guitar tuning (low to high): E-A-D-G-B-E"""
        self.string_tuning = {
            0: 40,  # Low E (E2)
            1: 45,  # A (A2) 
            2: 50,  # D (D3)
            3: 55,  # G (G3)
            4: 59,  # B (B3)
            5: 64   # High E (E4)
        }
        
    def setup_chord_database(self):
        """Comprehensive chord pattern database"""
        self.note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
        
        # Define chord intervals (semitones from root)
        self.chord_patterns = {
            # Triads
            'major': [0, 4, 7],
            'minor': [0, 3, 7],
            'diminished': [0, 3, 6],
            'augmented': [0, 4, 8],
            'sus2': [0, 2, 7],
            'sus4': [0, 5, 7],
            
            # Seventh chords
            'major7': [0, 4, 7, 11],
            'minor7': [0, 3, 7, 10],
            'dominant7': [0, 4, 7, 10],
            'diminished7': [0, 3, 6, 9],
            'half-diminished7': [0, 3, 6, 10],
            'major7#11': [0, 4, 7, 11, 18],
            
            # Extended chords
            'add9': [0, 4, 7, 14],
            'major9': [0, 4, 7, 11, 14],
            'minor9': [0, 3, 7, 10, 14],
            '9': [0, 4, 7, 10, 14],
            '11': [0, 4, 7, 10, 14, 17],
            '13': [0, 4, 7, 10, 14, 21],
            
            # Power chord
            'power': [0, 7],
            '5': [0, 7],  # Same as power chord
        }
        
        # Common guitar chord voicings
        self.guitar_voicings = {
            'C': {'frets': [3, 3, 2, 0, 1, 0], 'notes': ['G', 'C', 'E', 'G', 'C', 'E']},
            'G': {'frets': [3, 2, 0, 0, 3, 3], 'notes': ['G', 'B', 'D', 'G', 'B', 'G']},
            'Am': {'frets': [-1, 0, 2, 2, 1, 0], 'notes': [None, 'A', 'E', 'A', 'C', 'E']},
            'F': {'frets': [1, 3, 3, 2, 1, 1], 'notes': ['F', 'A', 'C', 'F', 'A', 'F']},
            'Dm': {'frets': [-1, -1, 0, 2, 3, 1], 'notes': [None, None, 'D', 'A', 'D', 'F']},
            'Em': {'frets': [0, 2, 2, 0, 0, 0], 'notes': ['E', 'B', 'E', 'G', 'B', 'E']},
        }

    def midi_to_note(self, midi_num: float) -> str:
        """Convert MIDI number to note name"""
        return self.note_names[int(midi_num) % 12]
    
    def note_to_number(self, note: str) -> int:
        """Convert note name to chromatic number (0-11)"""
        return self.note_names.index(note)
    
    def normalize_notes(self, notes: List[str]) -> List[int]:
        """Convert notes to chromatic numbers and remove duplicates"""
        return sorted(list(set([self.note_to_number(note) for note in notes])))
    
    def find_chord_intervals(self, notes: List[int]) -> List[int]:
        """Calculate intervals from the lowest note"""
        if not notes:
            return []
        root = min(notes)
        return [(note - root) % 12 for note in notes]
    
    def match_chord_pattern(self, intervals: List[int]) -> Tuple[str, float]:
        """Match intervals to known chord patterns"""
        intervals_set = set(intervals)
        best_match = ('unknown', 0.0)
        
        for chord_type, pattern in self.chord_patterns.items():
            pattern_set = set(pattern)
            
            # Calculate match score
            if len(pattern_set) == 0:
                continue
                
            intersection = len(intervals_set & pattern_set)
            union = len(intervals_set | pattern_set)
            
            # Jaccard similarity with bonus for exact matches
            if pattern_set == intervals_set:
                score = 1.0
            else:
                score = intersection / union if union > 0 else 0
                
            # Bonus for having the root and fifth
            if 0 in intervals_set and 7 in intervals_set:
                score += 0.1
                
            if score > best_match[1]:
                best_match = (chord_type, score)
        
        return best_match
    
    def find_best_root(self, notes: List[int]) -> Tuple[int, str, float]:
        """Try each note as root and find best chord match"""
        best_result = (0, 'unknown', 0.0)
        
        for root_candidate in notes:
            # Calculate intervals with this root
            intervals = [(note - root_candidate) % 12 for note in notes]
            intervals = sorted(list(set(intervals)))
            
            chord_type, confidence = self.match_chord_pattern(intervals)
            
            if confidence > best_result[2]:
                best_result = (root_candidate, chord_type, confidence)
        
        return best_result
    
    def extract_notes_in_timeframe(self, notes_data: Dict, start_time: float, end_time: float) -> List[Tuple[str, int]]:
        """Extract all notes active in a given timeframe"""
        active_notes = []
        
        for string_id, (pitches, intervals) in notes_data.items():
            if len(intervals) == 0:
                continue
                
            for i, interval in enumerate(intervals):
                note_start, note_end = interval[0], interval[1]
                
                # Check if note overlaps with our timeframe
                if note_start <= end_time and note_end >= start_time:
                    midi_pitch = pitches[i]
                    note_name = self.midi_to_note(midi_pitch)
                    note_number = int(midi_pitch) % 12
                    active_notes.append((note_name, note_number))
        
        return active_notes
    
    def temporal_smoothing(self, chord_sequence: List[ChordResult], min_duration: float = 0.5) -> List[ChordResult]:
        """Remove very short chord changes and smooth sequence"""
        if not chord_sequence:
            return []
            
        smoothed = []
        current_chord = chord_sequence[0]
        
        for next_chord in chord_sequence[1:]:
            duration = current_chord.time_end - current_chord.time_start
            
            # If chord is too short, try to merge with next if they're similar
            if duration < min_duration:
                if (next_chord.root_note == current_chord.root_note and 
                    next_chord.chord_quality == current_chord.chord_quality):
                    # Merge chords
                    current_chord.time_end = next_chord.time_end
                    current_chord.confidence = max(current_chord.confidence, next_chord.confidence)
                    continue
            
            smoothed.append(current_chord)
            current_chord = next_chord
        
        smoothed.append(current_chord)
        return smoothed
    
    def analyze_chord_progression(self, predictions: Dict, 
                                time_resolution: float = 0.25,
                                min_notes: int = 2,
                                confidence_threshold: float = 0.3) -> List[ChordResult]:
        """Main method to analyze chord progression"""
        
        notes_data = predictions['notes']
        times = predictions.get('times', [])
        
        if len(times) == 0:
            return []
        
        # Create time grid
        start_time = float(times[0])
        end_time = float(times[-1])
        time_points = np.arange(start_time, end_time, time_resolution)
        
        chord_results = []
        
        for i in range(len(time_points) - 1):
            window_start = time_points[i]
            window_end = time_points[i + 1]
            
            # Extract notes in this time window
            active_notes = self.extract_notes_in_timeframe(notes_data, window_start, window_end)
            
            if len(active_notes) < min_notes:
                continue
            
            # Get unique note names and numbers
            note_names = [note[0] for note in active_notes]
            note_numbers = [note[1] for note in active_notes]
            unique_notes = list(set(note_numbers))
            
            if len(unique_notes) < min_notes:
                continue
            
            # Find best chord match
            root_number, chord_quality, confidence = self.find_best_root(unique_notes)
            
            if confidence < confidence_threshold:
                continue
            
            root_note = self.note_names[root_number]
            bass_note = self.note_names[min(unique_notes)]  # Lowest note as bass
            
            # Create chord name
            if chord_quality == 'major':
                chord_name = root_note
            elif chord_quality == 'power' or chord_quality == '5':
                chord_name = f"{root_note}5"
            else:
                chord_name = f"{root_note}{chord_quality}"
            
            # Add slash chord notation if bass != root
            if bass_note != root_note and chord_quality not in ['power', '5']:
                chord_name += f"/{bass_note}"
            
            chord_result = ChordResult(
                time_start=window_start,
                time_end=window_end,
                chord_name=chord_name,
                root_note=root_note,
                chord_quality=chord_quality,
                notes=list(set(note_names)),
                confidence=confidence,
                bass_note=bass_note if bass_note != root_note else None
            )
            
            chord_results.append(chord_result)
        
        # Merge consecutive identical chords
        merged_chords = self.merge_consecutive_chords(chord_results)
        
        # Apply temporal smoothing
        smoothed_chords = self.temporal_smoothing(merged_chords)
        
        return smoothed_chords
    
    def merge_consecutive_chords(self, chord_results: List[ChordResult]) -> List[ChordResult]:
        """Merge consecutive identical chords"""
        if not chord_results:
            return []
        
        merged = []
        current = chord_results[0]
        
        for next_chord in chord_results[1:]:
            if (current.chord_name == next_chord.chord_name and 
                abs(current.time_end - next_chord.time_start) < 0.1):  # Small gap tolerance
                # Merge chords
                current.time_end = next_chord.time_end
                current.confidence = max(current.confidence, next_chord.confidence)
                # Combine notes
                current.notes = list(set(current.notes + next_chord.notes))
            else:
                merged.append(current)
                current = next_chord
        
        merged.append(current)
        return merged
    
    def print_chord_analysis(self, chord_results: List[ChordResult], max_chords: int = 20):
        """Print formatted chord analysis"""
        print(f"\n🎸 CHORD ANALYSIS RESULTS")
        print("=" * 60)
        
        if not chord_results:
            print("No chords detected.")
            return
        
        print(f"Found {len(chord_results)} chord segments")
        print(f"Showing first {min(max_chords, len(chord_results))} chords:\n")
        
        for i, chord in enumerate(chord_results[:max_chords]):
            duration = chord.time_end - chord.time_start
            confidence_bar = "█" * int(chord.confidence * 10) + "░" * (10 - int(chord.confidence * 10))
            
            print(f"{i+1:2d}. {chord.time_start:6.1f}s - {chord.time_end:6.1f}s "
                  f"({duration:4.1f}s) | {chord.chord_name:8s} | "
                  f"Confidence: {confidence_bar} {chord.confidence:.2f}")
            print(f"     Notes: {', '.join(sorted(chord.notes))}")
            
            if i > 0 and i % 5 == 0:
                print()
    
    def export_chord_progression(self, chord_results: List[ChordResult]) -> str:
        """Export chord progression as a simple string"""
        if not chord_results:
            return "No chords detected"
        
        # Group by chord name for summary
        chord_sequence = []
        for chord in chord_results:
            duration = chord.time_end - chord.time_start
            if duration >= 0.5:  # Only include chords lasting at least 0.5 seconds
                chord_sequence.append(chord.chord_name)
        
        # Remove consecutive duplicates
        unique_sequence = []
        for chord in chord_sequence:
            if not unique_sequence or chord != unique_sequence[-1]:
                unique_sequence.append(chord)
        
        return " | ".join(unique_sequence)

# Usage example and analysis function
def analyze_guitar_chords(predictions):
    """Analyze guitar chords from prediction data"""
    recognizer = AdvancedChordRecognizer()
    
    print("🔍 Analyzing guitar chords...")
    chord_results = recognizer.analyze_chord_progression(
        predictions, 
        time_resolution=0.25,  # Analyze every 0.25 seconds
        min_notes=2,           # Need at least 2 notes for a chord
        confidence_threshold=0.4  # Minimum confidence for chord detection
    )
    
    # Print analysis
    recognizer.print_chord_analysis(chord_results)
    
    # Export progression
    progression = recognizer.export_chord_progression(chord_results)
    print(f"\n🎵 CHORD PROGRESSION:")
    print(f"   {progression}")
    
    return chord_results, recognizer

# Quick analysis of specific time segments
def quick_chord_check(predictions, time_segments=None):
    """Quick chord analysis for specific time segments"""
    if time_segments is None:
        time_segments = [(10, 15), (18, 23), (45, 50), (80, 85)]
    
    recognizer = AdvancedChordRecognizer()
    
    print("\n🎯 QUICK CHORD CHECK:")
    print("-" * 40)
    
    for start, end in time_segments:
        notes = recognizer.extract_notes_in_timeframe(predictions['notes'], start, end)
        if len(notes) >= 2:
            note_names = [note[0] for note in notes]
            note_numbers = [note[1] for note in notes]
            unique_notes = list(set(note_numbers))
            
            root_number, chord_quality, confidence = recognizer.find_best_root(unique_notes)
            root_note = recognizer.note_names[root_number]
            
            chord_name = f"{root_note}{chord_quality}" if chord_quality != 'major' else root_note
            
            print(f"{start:2d}-{end:2d}s: {chord_name:8s} (conf: {confidence:.2f}) "
                  f"Notes: {', '.join(sorted(set(note_names)))}")
        else:
            print(f"{start:2d}-{end:2d}s: Not enough notes")

print("Advanced chord recognition system ready!")
print("\nUsage:")
print("  chord_results, recognizer = analyze_guitar_chords(predictions)")
print("  quick_chord_check(predictions)")

Advanced chord recognition system ready!

Usage:
  chord_results, recognizer = analyze_guitar_chords(predictions)
  quick_chord_check(predictions)


In [3]:
chord_results, recognizer = analyze_guitar_chords(predictions)
quick_chord_check(predictions)



🔍 Analyzing guitar chords...

🎸 CHORD ANALYSIS RESULTS
Found 17 chord segments
Showing first 17 chords:

 1.   37.2s -   38.0s ( 0.8s) | D5       | Confidence: ███████████ 1.10
     Notes: A, D
 2.   38.8s -   39.8s ( 1.0s) | C5       | Confidence: ███████████ 1.10
     Notes: C, G
 3.   41.0s -   44.0s ( 3.0s) | E5       | Confidence: ███████████ 1.10
     Notes: B, E
 4.   50.8s -   51.8s ( 1.0s) | C5       | Confidence: ███████████ 1.10
     Notes: C, G
 5.   52.2s -   53.0s ( 0.8s) | Baugmented/G | Confidence: ██████░░░░ 0.67
     Notes: B, G
 6.   57.0s -   57.5s ( 0.5s) | E5       | Confidence: ███████████ 1.10
     Notes: B, E

 7.   57.5s -   57.8s ( 0.2s) | D5       | Confidence: ███████████ 1.10
     Notes: A, D
 8.   62.8s -   63.5s ( 0.8s) | Csus2    | Confidence: ███████████ 1.10
     Notes: C, D, G
 9.   63.5s -   63.8s ( 0.2s) | G5       | Confidence: ███████████ 1.10
     Notes: D, G
10.   66.5s -   68.2s ( 1.8s) | Baugmented/G | Confidence: ██████░░░░ 0.67
     Notes: 

In [4]:
import numpy as np
from collections import defaultdict, Counter
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
import matplotlib.pyplot as plt
import matplotlib.patches as patches

@dataclass
class FretChord:
    time_start: float
    time_end: float
    fret_pattern: List[int]  # Frets for each string (6 strings, -1 = not played)
    primary_chord: str
    alternative_chords: List[Tuple[str, float]]  # Alternative interpretations with confidence
    notes_played: List[str]
    bass_note: str
    chord_intervals: List[int]
    chord_complexity: str  # 'simple', 'extended', 'complex'

class FlexibleGuitarAnalyzer:
    def __init__(self):
        self.setup_guitar_system()
        self.setup_flexible_chord_theory()
        
    def setup_guitar_system(self):
        """Setup guitar tuning and fretboard knowledge"""
        self.string_names = ['E', 'A', 'D', 'G', 'B', 'E']  # Low to high
        self.string_tuning = [40, 45, 50, 55, 59, 64]  # MIDI numbers for open strings
        self.note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
        
    def setup_flexible_chord_theory(self):
        """Setup comprehensive chord theory for any combination"""
        # More comprehensive chord patterns
        self.chord_intervals = {
            # Basic triads
            'major': [0, 4, 7],
            'minor': [0, 3, 7],
            'diminished': [0, 3, 6],
            'augmented': [0, 4, 8],
            'sus2': [0, 2, 7],
            'sus4': [0, 5, 7],
            
            # Seventh chords
            'major7': [0, 4, 7, 11],
            'minor7': [0, 3, 7, 10],
            'dominant7': [0, 4, 7, 10],
            'minor7b5': [0, 3, 6, 10],
            'diminished7': [0, 3, 6, 9],
            'major7#11': [0, 4, 7, 11, 18],
            'minor/major7': [0, 3, 7, 11],
            
            # Extensions
            'add9': [0, 4, 7, 14],
            'add11': [0, 4, 7, 17],
            'major9': [0, 4, 7, 11, 14],
            'minor9': [0, 3, 7, 10, 14],
            'dominant9': [0, 4, 7, 10, 14],
            'major11': [0, 4, 7, 11, 14, 17],
            'minor11': [0, 3, 7, 10, 14, 17],
            'dominant11': [0, 4, 7, 10, 14, 17],
            'major13': [0, 4, 7, 11, 14, 21],
            'minor13': [0, 3, 7, 10, 14, 21],
            'dominant13': [0, 4, 7, 10, 14, 21],
            
            # Altered dominants
            '7b5': [0, 4, 6, 10],
            '7#5': [0, 4, 8, 10],
            '7b9': [0, 4, 7, 10, 13],
            '7#9': [0, 4, 7, 10, 15],
            '7#11': [0, 4, 7, 10, 18],
            '7b13': [0, 4, 7, 10, 20],
            
            # Power chords and dyads
            'power': [0, 7],
            '5': [0, 7],
            'octave': [0, 12],
            'fourth': [0, 5],
            'fifth': [0, 7],
            
            # Complex/ambiguous
            '6': [0, 4, 7, 9],
            'minor6': [0, 3, 7, 9],
            '6/9': [0, 4, 7, 9, 14],
            'minor6/9': [0, 3, 7, 9, 14],
        }
        
        # Quality descriptors
        self.chord_qualities = {
            'major': 'Major',
            'minor': 'Minor', 
            'diminished': 'Diminished',
            'augmented': 'Augmented',
            'power': 'Power Chord',
            '5': 'Power Chord'
        }
    
    def fret_to_midi(self, string_idx: int, fret: int) -> int:
        """Convert string and fret to MIDI note number"""
        return self.string_tuning[string_idx] + fret
    
    def midi_to_note(self, midi_num: int) -> str:
        """Convert MIDI number to note name"""
        return self.note_names[midi_num % 12]
    
    def frets_to_notes_and_midi(self, fret_pattern: List[int]) -> Tuple[List[str], List[int]]:
        """Convert fret pattern to note names and MIDI numbers"""
        notes = []
        midi_notes = []
        
        for string_idx, fret in enumerate(fret_pattern):
            if fret >= 0:  # -1 means not played
                midi_note = self.fret_to_midi(string_idx, fret)
                note_name = self.midi_to_note(midi_note)
                notes.append(note_name)
                midi_notes.append(midi_note)
        
        return notes, midi_notes
    
    def analyze_chord_from_any_pattern(self, fret_pattern: List[int]) -> Tuple[str, List[Tuple[str, float]], List[int]]:
        """Analyze any fret pattern to find possible chord interpretations"""
        notes, midi_notes = self.frets_to_notes_and_midi(fret_pattern)
        
        if len(notes) < 2:
            return "Single Note", [], []
        
        # Get unique note classes (ignore octaves)
        note_classes = sorted(list(set([midi % 12 for midi in midi_notes])))
        
        if len(note_classes) < 2:
            return "Unison", [], note_classes
        
        # Try each note as potential root
        possible_chords = []
        
        for root_candidate in note_classes:
            intervals = [(note - root_candidate) % 12 for note in note_classes]
            intervals = sorted(list(set(intervals)))
            
            # Match against all chord patterns
            for chord_type, pattern in self.chord_intervals.items():
                score = self.calculate_chord_match_score(intervals, pattern)
                
                if score > 0.5:  # Only consider reasonable matches
                    root_note = self.note_names[root_candidate]
                    chord_name = self.format_chord_name(root_note, chord_type)
                    possible_chords.append((chord_name, score))
        
        # Sort by confidence
        possible_chords.sort(key=lambda x: x[1], reverse=True)
        
        # Determine primary chord
        if possible_chords:
            primary_chord = possible_chords[0][0]
            alternatives = possible_chords[1:5]  # Top 5 alternatives
        else:
            # Generate descriptive name for unrecognized patterns
            primary_chord = self.generate_descriptive_name(note_classes, notes)
            alternatives = []
        
        return primary_chord, alternatives, note_classes
    
    def calculate_chord_match_score(self, intervals: List[int], pattern: List[int]) -> float:
        """Calculate how well intervals match a chord pattern"""
        intervals_set = set(intervals)
        pattern_set = set(pattern)
        
        # Essential intervals that must be present
        essential = {0}  # Root is essential
        if 7 in pattern_set:  # Fifth is important for most chords
            essential.add(7)
        if 3 in pattern_set or 4 in pattern_set:  # Third defines major/minor
            essential.update({3, 4})
        
        # Check if essential intervals are present
        essential_score = len(essential & intervals_set) / len(essential) if essential else 1.0
        
        if essential_score < 0.5:  # Must have most essential intervals
            return 0.0
        
        # Calculate overall match
        if pattern_set <= intervals_set:  # All pattern notes present
            coverage = len(pattern_set) / len(intervals_set)
            return coverage * essential_score
        
        # Partial match
        intersection = len(intervals_set & pattern_set)
        union = len(intervals_set | pattern_set)
        jaccard = intersection / union if union > 0 else 0
        
        return jaccard * essential_score * 0.8  # Penalize partial matches
    
    def format_chord_name(self, root: str, chord_type: str) -> str:
        """Format chord name properly"""
        if chord_type == 'major':
            return root
        elif chord_type in ['power', '5']:
            return f"{root}5"
        elif chord_type == 'dominant7':
            return f"{root}7"
        elif chord_type == 'major7':
            return f"{root}maj7"
        elif chord_type == 'minor7':
            return f"{root}m7"
        elif chord_type == 'minor':
            return f"{root}m"
        elif chord_type == 'diminished':
            return f"{root}°"
        elif chord_type == 'augmented':
            return f"{root}+"
        elif chord_type == 'diminished7':
            return f"{root}°7"
        elif chord_type == 'minor7b5':
            return f"{root}m7♭5"
        else:
            return f"{root}{chord_type}"
    
    def generate_descriptive_name(self, note_classes: List[int], note_names: List[str]) -> str:
        """Generate descriptive name for unrecognized patterns"""
        if len(note_classes) == 2:
            interval = (note_classes[1] - note_classes[0]) % 12
            interval_names = {
                1: "minor 2nd", 2: "major 2nd", 3: "minor 3rd", 4: "major 3rd",
                5: "perfect 4th", 6: "tritone", 7: "perfect 5th", 8: "minor 6th",
                9: "major 6th", 10: "minor 7th", 11: "major 7th"
            }
            root = self.note_names[note_classes[0]]
            return f"{root} {interval_names.get(interval, 'interval')}"
        
        # For complex patterns, just list the notes
        unique_notes = sorted(list(set(note_names)))
        if len(unique_notes) <= 4:
            return f"{'/'.join(unique_notes)} cluster"
        else:
            return f"Complex chord ({len(unique_notes)} notes)"
    
    def extract_chord_shapes_flexible(self, predictions: Dict, time_resolution: float = 0.8) -> List[FretChord]:
        """Extract chord shapes with flexible analysis"""
        tablature = predictions['tablature']
        times = predictions.get('times', [])
        
        if len(times) == 0:
            return []
        
        chord_shapes = []
        time_step = int(len(times) * time_resolution / (times[-1] - times[0]))
        
        for i in range(0, len(times) - time_step, time_step // 2):  # 50% overlap
            start_idx = i
            end_idx = min(i + time_step, len(times) - 1)
            
            # Collect all fret patterns in this window
            patterns_in_window = []
            
            for time_idx in range(start_idx, end_idx):
                if time_idx < tablature.shape[1]:
                    pattern = [int(tablature[string][time_idx]) for string in range(6)]
                    # Only consider if at least 2 strings are active
                    if sum(1 for f in pattern if f >= 0) >= 2:
                        patterns_in_window.append(tuple(pattern))
            
            if not patterns_in_window:
                continue
            
            # Find most common pattern (or representative pattern)
            pattern_counts = Counter(patterns_in_window)
            most_common = pattern_counts.most_common(1)[0]
            representative_pattern = list(most_common[0])
            pattern_stability = most_common[1] / len(patterns_in_window)
            
            # Analyze the pattern
            primary_chord, alternatives, intervals = self.analyze_chord_from_any_pattern(representative_pattern)
            notes, midi_notes = self.frets_to_notes_and_midi(representative_pattern)
            
            # Determine bass note (lowest pitch)
            bass_note = self.midi_to_note(min(midi_notes)) if midi_notes else ""
            
            # Classify complexity
            complexity = self.classify_chord_complexity(len(set(notes)), intervals)
            
            chord_shape = FretChord(
                time_start=times[start_idx],
                time_end=times[end_idx],
                fret_pattern=representative_pattern,
                primary_chord=primary_chord,
                alternative_chords=alternatives,
                notes_played=notes,
                bass_note=bass_note,
                chord_intervals=intervals,
                chord_complexity=complexity
            )
            
            chord_shapes.append(chord_shape)
        
        return self.merge_similar_flexible_chords(chord_shapes)
    
    def classify_chord_complexity(self, num_unique_notes: int, intervals: List[int]) -> str:
        """Classify chord complexity"""
        if num_unique_notes <= 2:
            return "simple"
        elif num_unique_notes <= 4 and len(intervals) <= 4:
            return "standard"
        elif num_unique_notes <= 5:
            return "extended"
        else:
            return "complex"
    
    def merge_similar_flexible_chords(self, chord_shapes: List[FretChord]) -> List[FretChord]:
        """Merge similar chords with flexible criteria"""
        if not chord_shapes:
            return []
        
        merged = []
        current = chord_shapes[0]
        
        for next_chord in chord_shapes[1:]:
            # More flexible merging criteria
            same_primary = current.primary_chord == next_chord.primary_chord
            similar_pattern = self.patterns_similar(current.fret_pattern, next_chord.fret_pattern)
            time_gap = abs(current.time_end - next_chord.time_start)
            
            if (same_primary or similar_pattern) and time_gap < 0.5:
                # Merge chords
                current.time_end = next_chord.time_end
                # Combine alternative interpretations
                combined_alts = current.alternative_chords + next_chord.alternative_chords
                current.alternative_chords = sorted(list(set(combined_alts)), 
                                                  key=lambda x: x[1], reverse=True)[:3]
            else:
                merged.append(current)
                current = next_chord
        
        merged.append(current)
        return merged
    
    def patterns_similar(self, pattern1: List[int], pattern2: List[int], threshold: float = 0.7) -> bool:
        """Check if two fret patterns are similar"""
        if len(pattern1) != len(pattern2):
            return False
        
        matches = sum(1 for p1, p2 in zip(pattern1, pattern2) 
                     if (p1 == p2) or (p1 < 0 and p2 < 0))
        total_strings = len(pattern1)
        
        return (matches / total_strings) >= threshold
    
    def create_flexible_chord_diagram(self, fret_pattern: List[int], chord_info: FretChord, 
                                    figsize: Tuple[float, float] = (4, 5)) -> plt.Figure:
        """Create chord diagram with additional analysis info"""
        fig, (ax_chord, ax_info) = plt.subplots(2, 1, figsize=figsize, 
                                               gridspec_kw={'height_ratios': [4, 1]})
        
        # Draw main chord diagram
        self._draw_chord_diagram_in_ax(ax_chord, fret_pattern, chord_info.primary_chord)
        
        # Add analysis info
        ax_info.axis('off')
        info_text = f"Bass: {chord_info.bass_note}\n"
        info_text += f"Notes: {', '.join(chord_info.notes_played)}\n"
        info_text += f"Complexity: {chord_info.chord_complexity}\n"
        
        if chord_info.alternative_chords:
            alt_names = [alt[0] for alt in chord_info.alternative_chords[:2]]
            info_text += f"Alternatives: {', '.join(alt_names)}"
        
        ax_info.text(0.5, 0.5, info_text, ha='center', va='center', 
                    fontsize=10, transform=ax_info.transAxes)
        
        plt.tight_layout()
        return fig
    
    def _draw_chord_diagram_in_ax(self, ax, fret_pattern: List[int], chord_name: str):
        """Draw chord diagram in existing axes"""
        # Determine fret range to show
        active_frets = [f for f in fret_pattern if f >= 0]
        if not active_frets:
            num_frets = 4
            start_fret = 0
        else:
            min_fret = min(active_frets)
            max_fret = max(active_frets)
            
            if min_fret == 0 or max_fret <= 4:
                start_fret = 0
                num_frets = max(4, max_fret) + 1
            else:
                start_fret = max(0, min_fret - 1)
                num_frets = max_fret - start_fret + 2
        
        # Draw frets
        for fret in range(num_frets):
            y = num_frets - fret - 1
            line_width = 3 if fret + start_fret == 0 else 1
            ax.axhline(y=y, color='black', linewidth=line_width)
        
        # Draw strings
        for string in range(6):
            ax.axvline(x=string, color='black', linewidth=2)
        
        # Add fret numbers
        for fret in range(num_frets):
            actual_fret = fret + start_fret
            if actual_fret > 0:
                ax.text(-0.7, num_frets - fret - 1.5, str(actual_fret), 
                       ha='center', va='center', fontsize=8)
        
        # Add string names
        string_names = ['E', 'A', 'D', 'G', 'B', 'E']
        for string in range(6):
            ax.text(string, num_frets + 0.3, string_names[string], 
                   ha='center', va='center', fontweight='bold', fontsize=10)
        
        # Draw finger positions
        for string, fret in enumerate(fret_pattern):
            if fret >= 0:
                if start_fret == 0 and fret == 0:
                    # Open string
                    circle = plt.Circle((string, num_frets + 0.15), 0.12, 
                                      color='white', ec='black', linewidth=2)
                    ax.add_patch(circle)
                else:
                    # Fretted note
                    fret_pos = fret - start_fret
                    if 0 <= fret_pos < num_frets:
                        circle = plt.Circle((string, num_frets - fret_pos - 0.5), 0.15, 
                                          color='red', ec='black', linewidth=2)
                        ax.add_patch(circle)
            else:
                # Muted string
                ax.text(string, num_frets + 0.15, 'X', ha='center', va='center', 
                       fontweight='bold', fontsize=12, color='red')
        
        ax.set_xlim(-1, 6)
        ax.set_ylim(-0.8, num_frets + 0.8)
        ax.set_aspect('equal')
        ax.axis('off')
        
        # Add chord name
        ax.text(2.5, -0.5, chord_name, ha='center', va='center', 
               fontweight='bold', fontsize=14)
    
    def print_flexible_analysis(self, chord_shapes: List[FretChord], max_chords: int = 15):
        """Print comprehensive flexible analysis"""
        print(f"\n🎸 FLEXIBLE GUITAR ANALYSIS")
        print("=" * 80)
        
        if not chord_shapes:
            print("No chord patterns detected.")
            return
        
        print(f"Detected {len(chord_shapes)} distinct chord patterns")
        print(f"Showing first {min(max_chords, len(chord_shapes))} patterns:\n")
        
        string_names = ['E', 'A', 'D', 'G', 'B', 'E']
        
        for i, chord in enumerate(chord_shapes[:max_chords]):
            duration = chord.time_end - chord.time_start
            
            print(f"{i+1:2d}. {chord.time_start:6.1f}s - {chord.time_end:6.1f}s "
                  f"({duration:4.1f}s) | {chord.primary_chord:12s} | "
                  f"{chord.chord_complexity}")
            
            # Show fret pattern
            fret_display = []
            for j, fret in enumerate(chord.fret_pattern):
                if fret >= 0:
                    fret_display.append(f"{string_names[j]}:{fret}")
                else:
                    fret_display.append(f"{string_names[j]}:X")
            
            print(f"     Frets: {' | '.join(fret_display)}")
            print(f"     Notes: {', '.join(chord.notes_played)} (Bass: {chord.bass_note})")
            
            # Show alternatives if any
            if chord.alternative_chords:
                alts = [f"{alt[0]} ({alt[1]:.2f})" for alt in chord.alternative_chords[:2]]
                print(f"     Also could be: {', '.join(alts)}")
            
            print()

# Main analysis function
def analyze_flexible_guitar_patterns(predictions):
    """Flexible analysis that works with any fret patterns"""
    analyzer = FlexibleGuitarAnalyzer()
    
    print("🔍 Analyzing guitar patterns with flexible recognition...")
    
    # Extract patterns without rigid shape matching
    chord_shapes = analyzer.extract_chord_shapes_flexible(predictions, time_resolution=1.2)
    
    # Print comprehensive analysis
    analyzer.print_flexible_analysis(chord_shapes)
    
    # Create progression summary
    progression_with_alts = []
    for chord in chord_shapes:
        main = chord.primary_chord
        if chord.alternative_chords and chord.alternative_chords[0][1] > 0.8:
            main += f"/{chord.alternative_chords[0][0]}"
        progression_with_alts.append(main)
    
    print(f"\n🎵 DETECTED PROGRESSION:")
    print(f"   {' | '.join(progression_with_alts)}")
    
    return chord_shapes, analyzer

# Quick pattern check
def analyze_specific_fret_pattern(fret_pattern: List[int]):
    """Analyze any specific fret pattern"""
    analyzer = FlexibleGuitarAnalyzer()
    
    print(f"\n🎯 ANALYZING PATTERN: {fret_pattern}")
    print("-" * 50)
    
    primary, alternatives, intervals = analyzer.analyze_chord_from_any_pattern(fret_pattern)
    notes, midi_notes = analyzer.frets_to_notes_and_midi(fret_pattern)
    
    print(f"Primary interpretation: {primary}")
    print(f"Notes: {', '.join(notes)}")
    print(f"Intervals: {intervals}")
    
    if alternatives:
        print(f"\nAlternative interpretations:")
        for alt_name, confidence in alternatives[:3]:
            print(f"  {alt_name} (confidence: {confidence:.2f})")
    
    return primary, alternatives

print("Flexible guitar analysis system ready!")
print("This system analyzes ANY fret pattern without rigid shape constraints")
print("\nUsage:")
print("  patterns, analyzer = analyze_flexible_guitar_patterns(predictions)")
print("  analyze_specific_fret_pattern([3, 2, 0, 0, 3, 3])  # Test any pattern")

Flexible guitar analysis system ready!
This system analyzes ANY fret pattern without rigid shape constraints

Usage:
  patterns, analyzer = analyze_flexible_guitar_patterns(predictions)
  analyze_specific_fret_pattern([3, 2, 0, 0, 3, 3])  # Test any pattern


In [5]:
patterns, analyzer = analyze_flexible_guitar_patterns(predictions)

🔍 Analyzing guitar patterns with flexible recognition...

🎸 FLEXIBLE GUITAR ANALYSIS
Detected 114 distinct chord patterns
Showing first 15 patterns:

 1.    0.0s -    1.2s ( 1.2s) | C5           | simple
     Frets: E:3 | A:3 | D:X | G:X | B:X | E:X
     Notes: G, C (Bass: G)
     Also could be: C5 (1.00), Cfifth (1.00)

 2.    0.6s -    1.8s ( 1.2s) | C5           | standard
     Frets: E:3 | A:3 | D:X | G:2 | B:X | E:X
     Notes: G, C, A (Bass: G)
     Also could be: C5 (0.67), Cfifth (0.67)

 3.    1.2s -    2.3s ( 1.2s) | C major 3rd  | simple
     Frets: E:0 | A:3 | D:X | G:X | B:X | E:X
     Notes: E, C (Bass: E)

 4.    1.7s -    2.9s ( 1.2s) | C5           | simple
     Frets: E:3 | A:3 | D:X | G:X | B:X | E:X
     Notes: G, C (Bass: G)
     Also could be: C5 (1.00), Cfifth (1.00)

 5.    2.3s -    3.5s ( 1.2s) | C5           | simple
     Frets: E:3 | A:3 | D:X | G:X | B:X | E:X
     Notes: G, C (Bass: G)
     Also could be: C5 (1.00), Cfifth (1.00)

 6.    7.0s -    8.2s ( 1