<a href="https://colab.research.google.com/github/kefoto/CNN/blob/main/14-Langfuse-connection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Audio Inference with Sep28k and Langfuse and designed prompt
> Using Qwen2.5-audio-7B-Instruct for audio analysis & Mistral-7B-Instruct-v0.3 and instructor wrapper for structured output

require 30GB of VRAM since Mistral-7B requires 15 GB of Vram

In this notebook, we read from the Sep28k dataset stored on the `vanderbilt-dsi` Huggingface account. We leverage 2 audio examples and ask a few questions about the audio contained. Below are the results of this exploration.

See [Issue #11](https://github.com/vanderbilt-data-science/stutter-models/issues/11) for more information on setting up your Huggingface Token and adding this in Google Colab.

In [1]:
!pip -q install --upgrade pip
!pip -q install "transformers>=4.43" accelerate einops librosa soundfile sentencepiece torchaudio torch torchcodec==0.7 langfuse
!pip -q install pydantic instructor

In [2]:
from io import BytesIO
from urllib.request import urlopen
from pathlib import Path
import json
import time

import librosa
import numpy as np
import torch
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor, Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor, AutoModelForCausalLM, AutoTokenizer, pipeline
from datasets import load_dataset
import instructor

import os
from google.colab import userdata

MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"

# Repro-ish (note: generation still has stochasticity unless you fix sampling params)
torch.manual_seed(42)

<torch._C.Generator at 0x7ee5fdfa7050>

## Connecting Langfuse

In [3]:
from langfuse import Langfuse
os.environ['LC_PK'] = userdata.get('LC_PK')
os.environ['LC_SK'] = userdata.get('LC_SK')
os.environ['LC_H'] = userdata.get('LC_H')
langfuse = Langfuse(
    public_key=os.environ['LC_PK'],
    secret_key=os.environ['LC_SK'],
    host=os.environ['LC_H']
)

In [4]:
print("Connected:", langfuse)

Connected: <langfuse._client.client.Langfuse object at 0x7ee4547479e0>


## Downloading Model and Dataset

In [5]:
os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')

processor = AutoProcessor.from_pretrained(MODEL_ID)
# model = Qwen2_5OmniForConditionalGeneration.from_pretrained(MODEL_ID,
#                                                            device_map="auto", dtype=torch.bfloat16,)
model = Qwen2AudioForConditionalGeneration.from_pretrained(MODEL_ID,
                                                           device_map="auto",
                                                           dtype=torch.bfloat16,)

sr = processor.feature_extractor.sampling_rate
model.eval()

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Qwen2AudioForConditionalGeneration(
  (audio_tower): Qwen2AudioEncoder(
    (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
    (embed_positions): Embedding(1500, 1280)
    (layers): ModuleList(
      (0-31): 32 x Qwen2AudioEncoderLayer(
        (self_attn): Qwen2AudioAttention(
          (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
          (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=1280, out_features=5120, bias=True)
        (fc2): Linear(in_features=5120, out_features=1280, bias=True)
        (final_layer_norm): LayerNorm

In [6]:
chat_model_id = "mistralai/Mistral-7B-Instruct-v0.3"

chat_tokenizer = AutoTokenizer.from_pretrained(chat_model_id)
chat_model = AutoModelForCausalLM.from_pretrained(
    chat_model_id,
    device_map="auto",
    dtype=torch.bfloat16
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
import time
from typing import List, Dict, Any

#define OpenAi api styled class for instructor output
class MockMessage:
    def __init__(self, content: str, role: str = "assistant"):
        self.role = role
        self.content = content

class MockChoice:
    def __init__(self, text: str, index: int = 0):
        self.index = index
        self.message = MockMessage(content=text)
        self.finish_reason = "stop"

class MockCompletion:
    def __init__(
        self,
        text: str,
        model: str = "local-llm",
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
    ):
        self.id = f"chatcmpl-local-{int(time.time())}"
        self.object = "chat.completion"
        self.created = int(time.time())
        self.model = model
        self.choices: List[MockChoice] = [MockChoice(text)]
        self.usage: Dict[str, Any] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }

In [18]:
import re

# reponse string sanitization helper
def strip_code_fences(text: str) -> str:
    """Remove Markdown ```json ... ``` or ``` blocks."""
    # Remove triple backticks and language hints
    text = re.sub(r"^```(?:json|JSON)?", "", text.strip())
    text = re.sub(r"```$", "", text.strip())
    return text.strip()

In [21]:
generator = pipeline("text-generation", model=chat_model, tokenizer=chat_tokenizer)

def mistral_create(messages, **kwargs):
    prompt = "\n".join([m["content"] for m in messages])
    out = generator(prompt)[0]["generated_text"]
    out = strip_code_fences(out)
    return MockCompletion(
        text=out,
        model="mistralai/Mistral-7B-Instruct-v0.3",
    )


mistral_structured = instructor.patch(
    create=mistral_create,
    mode=instructor.Mode.JSON_SCHEMA,
)

Device set to use cuda:0


In [9]:
dataset = load_dataset("vanderbilt-dsi/sep-28k-extended", token=os.getenv("HF_TOKEN"))

In [10]:
system_prompt = """
You are a multimodal model trained to detect stuttering disfluencies in both adults and children.
You analyze the provided audio waveform and its transcript to identify disfluencies with high precision.

Task:
Detect and label any stuttering disfluencies in the input audio‚Äìtext pair.
Use both acoustic cues (pauses, prolongations, effort) and linguistic patterns (repetition, interjection, revision).

Definitions:
- Repetition: Unintentional repeating of sounds, syllables, or words (e.g., "b-b-ball").
- Prolongation: A sound held longer than normal (e.g., "ssssun").
- Block: A stoppage of airflow or voicing before or during speech (e.g., "‚Äî‚Äîbook").

Output format rules:
- Output **only** valid JSON ‚Äî no extra text, explanations, or comments.
- Use **double quotes** around all keys and string values.
- Output must be a JSON **array** of objects, not Python-style dicts.
- Numeric fields must be valid floats between 0 and 1 for confidence_score.
- Each object must follow this exact structure:

[
  {
    "time_start": <float>,
    "time_end": <float>,
    "transcript_segment": "<text>",
    "disfluency_type": "<repetition | prolongation | block>",
    "confidence_score": <float 0-1>,
    "commentary": "<optional reasoning or null>"
  }
]

If no disfluencies are detected, output `[]` (an empty JSON array) and nothing else.
"""


question_prompt = """
Analyze the following audio and transcript.
Identify all stuttering disfluencies and return them in the required JSON format.
"""


In [28]:
from pydantic import BaseModel, Field, RootModel
from typing import List, Optional

#for local instruct or pydantic
class Disfluency(BaseModel):
    time_start: float
    time_end: float
    transcript_segment: str
    disfluency_type: str = Field(..., pattern="repetition|prolongation|block")
    confidence_score: float = Field(..., ge=0, le=1)
    commentary: Optional[str] = None

# class DisfluencyResults(RootModel[List[Disfluency]]):
#     pass
class DisfluencyResults(BaseModel):
    results: List[Disfluency]

In [12]:
#for function calling: more available through remote client apis
schema = {
    "name": "DisfluencyResults",
    "schema": {
        "type": "object",
        "properties": {
            "results": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "time_start": {"type": "number"},
                        "time_end": {"type": "number"},
                        "transcript_segment": {"type": "string"},
                        "disfluency_type": {"type": "string", "enum": ["repetition", "prolongation", "block"]},
                        "confidence_score": {"type": "number"},
                        "commentary": {"type": "string"}
                    },
                    "required": ["time_start", "time_end", "transcript_segment", "disfluency_type", "confidence_score"]
                }
            }
        },
        "required": ["results"]
    }
}

## Define langfuse observation

In [23]:
import json
from pydantic import ValidationError
from instructor.exceptions import InstructorRetryException

def safe_structured_parse(messages, response_model):
    try:
        structured = mistral_structured(
            messages=messages,
            response_model=response_model,
        )
        return structured

    except InstructorRetryException as e:
        print("\n‚ö†Ô∏è InstructorRetryException triggered!")
        print(f"Exception type: {type(e).__name__}")
        print(f"Raw exception message:\n{e}\n")
        if hasattr(e, "args") and e.args:
            print("Args passed to exception:", e.args)
        raise  # re-raise for now; remove this line if you want to continue execution

    except ValidationError as e:
        print("\n‚ùå Pydantic ValidationError:")
        print(e)
        # Try to extract invalid JSON from the input value
        try:
            invalid_json = e.errors()[0].get("input_value", "")
            print("\n--- Invalid JSON string ---")
            print(invalid_json[:500])  # print a snippet if long
            print("--- end ---\n")
            # Optionally save to a temp file for inspection
            with open("failed_output.json", "w", encoding="utf-8") as f:
                f.write(invalid_json)
                print("üßæ Saved problematic output to failed_output.json")
        except Exception as inner:
            print("Couldn't extract invalid JSON:", inner)
        raise  # optional: comment out if you want graceful fallback

    except json.JSONDecodeError as e:
        print("\n‚ùå JSONDecodeError:")
        print(f"Error message: {e}")
        print(f"Line {e.lineno}, Column {e.colno}")
        raise

    except Exception as e:
        print("\nüö® Unexpected error in structured parse:")
        print(f"Type: {type(e).__name__}")
        print(f"Message: {e}")
        raise


In [26]:
from langfuse import observe
import time, torch, json, instructor
from pydantic import ValidationError

@observe(name="Qwen2Audio-stutter-pipeline")
def run_stuttering_detection_pipeline(model, processor, sample, system_prompt, question_prompt):
    """
    Full inference pipeline: takes audio sample and prompts,
    runs Qwen2Audio model, and returns JSON output.
    """
    # Build chat structure
    conversation = [
        {'role': 'system', 'content': system_prompt},
        {"role": "user", "content": [
            {"type": "audio", "audio": sample},
            {"type": "text", "text": question_prompt},
        ]},
    ]

    # Step 1: Prepare text + audio
    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    audios = [sample["array"]]

    # Step 2: Tokenize & prepare tensors
    inputs = processor(text=text, audio=audios, sampling_rate=sr, return_tensors="pt", padding=True)
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # print(f"#audio segments: {len(audios)} | input_ids shape: {inputs['input_ids'].shape}")

    # Step 3: Inference
    start = time.time()
    with torch.no_grad():
        generate_ids = model.generate(**inputs, max_length=1024)
    elapsed = time.time() - start

    # Step 4: Decode
    text_id = generate_ids
    response_ids = text_id[:, inputs["input_ids"].size(1):]
    response = processor.batch_decode(
        response_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]

    print(response)

    structured = safe_structured_parse(
      messages=[
        {"role": "system", "content": (
            "You are a strict JSON emitter. "
            "Output strictly valid JSON. Use double quotes for all keys and string values. "
            "Do not wrap output in code fences or Python-style dicts. "
            "Return a JSON array of Disfluency objects. If no disfluencies are detected, return an empty array `[]`."
        )},
        {"role": "user", "content": f"Convert this into valid DisfluencyResults JSON:\n{response}"}
    ],
      response_model=DisfluencyResults,
    )



    print(f"‚è±Ô∏è Generation took {elapsed:.2f}s")
    return structured

## define visualization method

In [14]:
# @title
from IPython.display import Audio, display
import matplotlib.pyplot as plt
import torch
import numpy as np

def play_audio_via_get_all_samples(dataset, index=0, plot=True, play=True):
    """
    Decode, play, and visualize an audio clip using torchcodec's get_all_samples().
    Displays waveform (top) and spectrogram (bottom).

    Assumptions:
      dataset is Huggingface object with audio column as decoded Torchcodec object
      Torchcodec 0.7

    """
    # Handle dataset split
    if hasattr(dataset, "keys") and "train" in dataset:
        sample = dataset["train"][index]
    else:
        sample = dataset[index]

    audio = sample["audio"]

    # Decode via torchcodec
    samples = audio.get_all_samples()
    waveform = samples.data
    sr = samples.sample_rate

    if isinstance(waveform, torch.Tensor):
        waveform = waveform.squeeze().cpu().numpy()

    if plot:
        # Prepare figure
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

        # --- Waveform ---
        ax1.plot(np.arange(len(waveform)) / sr, waveform, linewidth=0.8)
        ax1.set_title(f"Waveform (index={index}, {sr} Hz)")
        ax1.set_xlabel("Time (s)")
        ax1.set_ylabel("Amplitude")
        ax1.grid(True, alpha=0.3)

        # --- Spectrogram ---
        Pxx, freqs, bins, im = ax2.specgram(
            waveform,
            NFFT=1024,
            Fs=sr,
            noverlap=512,
            cmap="magma"
        )
        ax2.set_title("Spectrogram")
        ax2.set_xlabel("Time (s)")
        ax2.set_ylabel("Frequency (Hz)")
        # fig.colorbar(im, ax=ax2, label="Intensity (dB)")

        # plt.legend(None)
        plt.tight_layout()
        plt.show()

    # Play in notebook
    if play:
        display(Audio(waveform, rate=sr))

    return waveform, sr

## Call function

In [30]:
response = "[{\"time_start\": \"0.58\", \"time_end\": \"1.24\", \"transcript_segment\": \"myself\", \"disfluency_type\": \"stuttering\", \"confidence_score\": \"0.7\", \"commentary\": \"Short stuttering hesitation.\"}, {\"time_start\": \"1.63\", \"time_end\": \"2.09\", \"transcript_segment\": \"limiting\", \"disfluency_type\": \"stuttering\", \"confidence_score\": \"0.8\", \"commentary\": \"Brief stuttering hesitation.\"}]"

structured = safe_structured_parse(
      messages=[
        {"role": "system", "content": (
            "You are a strict JSON emitter. "
            "Output strictly valid JSON. Use double quotes for all keys and string values. "
            "Do not wrap output in code fences or Python-style dicts. "
            "Return a JSON array of Disfluency objects. If no disfluencies are detected, return an empty array `[]`."
        )},
        {"role": "user", "content": f"Convert this into valid DisfluencyResults JSON:\n{response}"}
    ],
      response_model=DisfluencyResults,
    )

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



‚ö†Ô∏è InstructorRetryException triggered!
Exception type: InstructorRetryException
Raw exception message:
<failed_attempts>

<generation number="1">
<exception>
    1 validation error for DisfluencyResults
  Invalid JSON: trailing characters at line 1 column 181 [type=json_invalid, input_value='{"time_start": "0.58", "...ering hesitation."\n  }', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/json_invalid
</exception>
<completion>
    <__main__.MockCompletion object at 0x7ee06c868950>
</completion>
</generation>

</failed_attempts>

<last_exception>
    1 validation error for DisfluencyResults
  Invalid JSON: trailing characters at line 1 column 181 [type=json_invalid, input_value='{"time_start": "0.58", "...ering hesitation."\n  }', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/json_invalid
</last_exception>

Args passed to exception: (1 validation error for DisfluencyResults
  Invalid JSON: trailing charac

InstructorRetryException: <failed_attempts>

<generation number="1">
<exception>
    1 validation error for DisfluencyResults
  Invalid JSON: trailing characters at line 1 column 181 [type=json_invalid, input_value='{"time_start": "0.58", "...ering hesitation."\n  }', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/json_invalid
</exception>
<completion>
    <__main__.MockCompletion object at 0x7ee06c868950>
</completion>
</generation>

</failed_attempts>

<last_exception>
    1 validation error for DisfluencyResults
  Invalid JSON: trailing characters at line 1 column 181 [type=json_invalid, input_value='{"time_start": "0.58", "...ering hesitation."\n  }', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/json_invalid
</last_exception>

In [29]:
sample = dataset["train"]["audio"][0]

response = run_stuttering_detection_pipeline(
    model=model,
    processor=processor,
    sample=sample,
    system_prompt=system_prompt,
    question_prompt=question_prompt
)

print("\n=== MODEL RESPONSE ===")
print(response)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[{\"time_start\": \"0.58\", \"time_end\": \"1.24\", \"transcript_segment\": \"myself\", \"disfluency_type\": \"stuttering\", \"confidence_score\": \"0.7\", \"commentary\": \"Short stuttering hesitation.\"}, {\"time_start\": \"1.63\", \"time_end\": \"2.09\", \"transcript_segment\": \"limiting\", \"disfluency_type\": \"stuttering\", \"confidence_score\": \"0.8\", \"commentary\": \"Brief stuttering hesitation.\"}]

‚ö†Ô∏è InstructorRetryException triggered!
Exception type: InstructorRetryException
Raw exception message:
<failed_attempts>

<generation number="1">
<exception>
    1 validation error for DisfluencyResults
  Invalid JSON: key must be a string at line 1 column 2 [type=json_invalid, input_value='{\\"time_start\\": \\"0....ering hesitation."\n  }', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/json_invalid
</exception>
<completion>
    <__main__.MockCompletion object at 0x7ee06caa39b0>
</completion>
</generation>

</failed_attempts>

<last_e

InstructorRetryException: <failed_attempts>

<generation number="1">
<exception>
    1 validation error for DisfluencyResults
  Invalid JSON: key must be a string at line 1 column 2 [type=json_invalid, input_value='{\\"time_start\\": \\"0....ering hesitation."\n  }', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/json_invalid
</exception>
<completion>
    <__main__.MockCompletion object at 0x7ee06caa39b0>
</completion>
</generation>

</failed_attempts>

<last_exception>
    1 validation error for DisfluencyResults
  Invalid JSON: key must be a string at line 1 column 2 [type=json_invalid, input_value='{\\"time_start\\": \\"0....ering hesitation."\n  }', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/json_invalid
</last_exception>

This needs smaller prompt for Qwen

## filtering dataset to only stuttering sample

In [None]:
# dataset['train'] = dataset['train'].add_column("idx", list(range(len(dataset['train']))))

# df = dataset['train'].to_pandas()
# # Based on Ida's code
# # Filter the DataFrame based on the specified criteria using OR and string comparison
# filtered_df = df[(df['Prolongation'] == '3') | (df['SoundRep'] == '3') | (df['WordRep'] == '3') | (df['Block'] == '3')]
# indices = filtered_df.index.tolist()

# filtered_dataset = dataset['train'].select(indices)


## Scaling

In [None]:
# from tqdm import tqdm
# import json, os, time

# # === SETTINGS ===
# output_path = "stuttering_results.jsonl"
# save_every = 5        # how often to checkpoint
# max_retries = 3       # how many times to retry a failed item

# # === LOAD EXISTING PROGRESS (if any) ===
# completed_ids = set()
# if os.path.exists(output_path):
#     with open(output_path, "r") as f:
#         for line in f:
#             try:
#                 record = json.loads(line)
#                 completed_ids.add(record["idx"])
#             except json.JSONDecodeError:
#                 continue
#     print(f"Resuming ‚Äî found {len(completed_ids)} completed items")

# # === RUN SAFELY ===
# for item in tqdm(filtered_dataset):
#     idx = item["idx"]

#     # Skip if already done
#     if idx in completed_ids:
#         continue

#     sample = item["audio"]

#     # Retry wrapper for transient failures
#     for attempt in range(max_retries):
#         try:
#             response = run_stuttering_detection_pipeline(
#                 model=model,
#                 processor=processor,
#                 sample=sample,
#                 system_prompt=system_prompt,
#                 question_prompt=question_prompt,
#             )

#             record = {
#                 "idx": idx,
#                 "response": response,
#                 "timestamp": time.time(),
#             }

#             # Save incrementally
#             with open(output_path, "a") as f:
#                 f.write(json.dumps(record) + "\n")

#             completed_ids.add(idx)
#             break  # success ‚Üí move to next sample

#         except Exception as e:
#             print(f"‚ö†Ô∏è Error on idx={idx} (attempt {attempt+1}/{max_retries}): {e}")
#             time.sleep(2)
#             if attempt + 1 == max_retries:
#                 # Log failure for later inspection
#                 with open("stuttering_failures.log", "a") as ef:
#                     ef.write(f"{idx}\t{str(e)}\n")
