In [None]:
# %pip installs into the environment of the running kernel; the version is pinned to the
# release this notebook was last executed with so reruns resolve the same package.
%pip install diarizationlm==0.1.5

Collecting diarizationlm
  Downloading diarizationlm-0.1.5-py3-none-any.whl.metadata (11 kB)
Collecting colortimelog (from diarizationlm)
  Downloading colortimelog-0.0.9-py3-none-any.whl.metadata (1.5 kB)
Collecting word-levenshtein (from diarizationlm)
  Downloading word_levenshtein-0.0.3.tar.gz (7.7 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Downloading diarizationlm-0.1.5-py3-none-any.whl (24 kB)
Downloading colortimelog-0.0.9-py3-none-any.whl (7.9 kB)
Building wheels for collected packages: word-levenshtein
  Building wheel for word-levenshtein (pyproject.toml) ... [?25l[?25hdone
  Created wheel for word-levenshtein: filename=word_levenshtein-0.0.3-cp312-cp312-linux_x86_64.whl size=78409 sha256=87a7fec6768d1884e94fc254ee2c5151e16c5881c69529492a80e6bc89ae4577
  Stored in directory: /root/.cache/pip/wheels/a5/90/af/c9de0d35b502010a9221b531cce461c295a5

In [None]:
from transformers import LlamaForCausalLM, AutoTokenizer
from diarizationlm import utils


In [None]:
import json
import os
import glob
from transformers import LlamaForCausalLM, AutoTokenizer
from diarizationlm import utils

# Configuration
# INPUT_FOLDER: directory that process_all_json_files scans for "*.json" inputs.
INPUT_FOLDER = "input_test"
# OUTPUT_FOLDER: directory the relabelled JSON files are written to (created if missing).
OUTPUT_FOLDER = "output_adaptive"
# MODEL_NAME: identifier passed to AutoTokenizer / LlamaForCausalLM.from_pretrained.
MODEL_NAME = "google/DiarizationLM-8b-Fisher-v2"

In [None]:
import torch

# Report which accelerator (if any) this kernel can see before the model is loaded.
cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")

if not cuda_ready:
    print("Running on CPU - this will be very slow!")
else:
    # Device identity
    print(f"GPU device: {torch.cuda.get_device_name()}")
    print(f"GPU count: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")

    # Memory counters are reported in bytes; convert to GiB for display
    bytes_per_gib = 1024**3
    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / bytes_per_gib:.2f} GB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved() / bytes_per_gib:.2f} GB")

CUDA available: True
GPU device: NVIDIA A100-SXM4-40GB
GPU count: 1
Current device: 0
GPU memory allocated: 0.00 GB
GPU memory reserved: 0.00 GB


In [None]:
import json
import os
import glob
import re
import torch
from transformers import AutoTokenizer, LlamaForCausalLM
from diarizationlm import utils

def parse_word_level(json_data):
    """Flatten the single utterance in ``json_data`` into one record per word.

    Args:
        json_data: Mapping with an ``"utterances"`` sequence holding exactly one
            item; that item must provide whitespace-separated ``"hyp_text"``
            and ``"hyp_spk"`` strings.

    Returns:
        A list of dicts ``{"word", "spk", "pos"}`` where ``pos`` is the
        zero-based index of the word in the utterance.

    Raises:
        ValueError: If there is not exactly one utterance, or if the number of
            words differs from the number of speaker labels.
    """
    has_single_utterance = (
        "utterances" in json_data and len(json_data["utterances"]) == 1
    )
    if not has_single_utterance:
        raise ValueError("JSON must contain exactly one utterance")

    utterance = json_data["utterances"][0]
    tokens = utterance["hyp_text"].split()
    labels = utterance["hyp_spk"].split()

    # Every word needs its own speaker label for the 1:1 pairing below.
    if len(tokens) != len(labels):
        raise ValueError("Word/Speaker mismatch in original JSON")

    entries = []
    for position, token in enumerate(tokens):
        entries.append({"word": token, "spk": labels[position], "pos": position})
    return entries


def chunk_adaptive(word_entries, overlap_window=5):
    """Split word records into overlapping chunks cut at speaker changes.

    A chunk is closed at every index where the speaker label differs from the
    previous word's label (the word carrying the new label is not part of the
    closed chunk). The following chunk begins ``overlap_window`` words before
    that index (clamped at 0), so consecutive chunks share the words around
    the transition.

    Args:
        word_entries: Sequence of dicts that each have a ``"spk"`` key.
        overlap_window: Number of words before a speaker change that are
            repeated at the start of the next chunk.

    Returns:
        A list of slices of ``word_entries``; the slices reference the same
        entry objects as the input. Also prints the number of chunks.
    """
    total = len(word_entries)
    pieces = []
    chunk_start = 0

    for boundary in range(1, total):
        # Same speaker on both sides: keep growing the current chunk.
        if word_entries[boundary - 1]["spk"] == word_entries[boundary]["spk"]:
            continue

        segment = word_entries[chunk_start:boundary]
        if segment:
            pieces.append(segment)

        # Rewind so the next chunk re-reads the words leading into the change.
        chunk_start = max(0, boundary - overlap_window)

    # Whatever remains after the last speaker change forms the tail chunk.
    if chunk_start < total:
        pieces.append(word_entries[chunk_start:])

    print(f"Adaptive chunking created {len(pieces)} chunks.")
    return pieces

def chunk_to_hypothesis_optimized(chunk):
    """Render a chunk as text with a speaker tag only where the speaker changes.

    Args:
        chunk: Sequence of dicts with ``"word"`` and ``"spk"`` keys.

    Returns:
        A space-joined string in which ``<speaker:X>`` precedes the first word
        of each run of words sharing speaker ``X``; ``""`` for an empty chunk.
    """
    if not chunk:
        return ""

    tokens = []
    last_tagged = None  # speaker of the most recently emitted tag

    for entry in chunk:
        speaker = entry["spk"]
        if speaker != last_tagged:
            tokens.append(f"<speaker:{speaker}>")
            last_tagged = speaker
        tokens.append(entry["word"])

    return " ".join(tokens)


def process_chunk_with_diarizationlm(word_chunk, model, tokenizer):
    """Run one chunk through the model and map its completion back onto the chunk text.

    The chunk is rendered with ``chunk_to_hypothesis_optimized``, suffixed with
    ``" --> "`` and fed to ``model.generate`` with greedy decoding. Only the newly
    generated tokens are decoded, and the result is passed to
    ``utils.transfer_llm_completion`` together with the hypothesis.

    Args:
        word_chunk: Sequence of word dicts (``"word"``, ``"spk"``) for this chunk.
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        The string returned by ``utils.transfer_llm_completion``; if that call
        raises, the unmodified hypothesis string is returned instead.
    """
    hypothesis = chunk_to_hypothesis_optimized(word_chunk)
    print(f"   → Processing {len(word_chunk)} words...")

    torch.cuda.empty_cache()

    # Send the inputs to wherever the model actually lives instead of assuming "cuda";
    # with device_map="auto" the placement is decided at load time.
    inputs = tokenizer([hypothesis + " --> "], return_tensors="pt").to(model.device)
    prompt_len = inputs.input_ids.shape[1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            # Budget: 110% of the prompt length in tokens, capped at 800.
            max_new_tokens=min(800, int(prompt_len * 1.1)),
            do_sample=False,
            temperature=1.0,
            num_beams=1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Drop the echoed prompt tokens so only the completion is decoded.
    completion = tokenizer.batch_decode(
        output[:, prompt_len:],
        skip_special_tokens=True
    )[0]

    try:
        transferred = utils.transfer_llm_completion(completion, hypothesis)
    except Exception as exc:
        # Best-effort: keep the original labels for this chunk, but say why.
        # `Exception` (not a bare except) lets Ctrl-C / SystemExit propagate.
        print(f"WARNING: transfer_llm_completion failed: {exc!r}")
        transferred = hypothesis

    return transferred


def parse_diarization_output_optimized(output, original_chunk):
    """Turn tagged model output into one speaker label per word of the chunk.

    ``output`` is scanned for ``<speaker:N>`` tags (``N`` numeric); every
    whitespace-separated word following a tag, up to the next ``<``, is assigned
    speaker ``N``. The result is then fitted to the chunk length: padded with the
    last assigned speaker if too short, truncated if too long.

    If the output carries no usable assignment at all (no tags, or tags followed
    by no words), the chunk's original labels are returned unchanged.

    Args:
        output: Text containing ``<speaker:N>`` tags.
        original_chunk: Sequence of word dicts with a ``"spk"`` key.

    Returns:
        A list of speaker labels with exactly ``len(original_chunk)`` items.
    """
    original_speakers = [w["spk"] for w in original_chunk]

    pattern = r"<speaker:(\d+)>\s*([^<]*)"
    matches = re.findall(pattern, output)

    if not matches:
        print("   WARNING: No speaker tags detected, fallback using original speakers.")
        return original_speakers

    assigned = []
    for spk, text in matches:
        assigned.extend([spk] * len(text.split()))

    # Tags without any words give no information; keep each word's original label
    # instead of collapsing the whole chunk onto its first speaker.
    if not assigned:
        print("   WARNING: Speaker tags carried no words, fallback using original speakers.")
        return original_speakers

    # Pad if model produced fewer words
    missing = len(original_chunk) - len(assigned)
    if missing > 0:
        assigned.extend([assigned[-1]] * missing)

    return assigned[:len(original_chunk)]


def apply_chunk_speaker_updates(global_list, chunk, updated):
    """Write new speaker labels into the chunk's word dicts and count the changes.

    Entries are paired positionally with ``updated`` (pairing stops at the
    shorter of the two) and each entry's ``"spk"`` is overwritten in place when
    it differs from the new label. When this is called chunk by chunk, a later
    call overwrites labels an earlier call set on entries the chunks share.

    NOTE: ``global_list`` is accepted but never read here. The global word list
    only sees the update when it holds the very same dict objects as ``chunk``.

    Args:
        global_list: Unused; kept for interface compatibility.
        chunk: Sequence of word dicts with a ``"spk"`` key; mutated in place.
        updated: Sequence of new speaker labels, aligned with ``chunk``.

    Returns:
        The number of entries whose label was changed.
    """
    changed = 0
    for entry, label in zip(chunk, updated):
        if entry["spk"] == label:
            continue
        entry["spk"] = label
        changed += 1
    return changed



def rebuild_json(global_words, original_json):
    """Build an output document carrying the (possibly relabelled) words.

    Args:
        global_words: Sequence of word dicts with ``"word"`` and ``"spk"`` keys.
        original_json: Source document; its first utterance supplies every field
            other than ``"hyp_text"`` and ``"hyp_spk"``.

    Returns:
        A shallow copy of ``original_json`` whose ``"utterances"`` is a one-item
        list: a shallow copy of the first utterance with ``"hyp_text"`` and
        ``"hyp_spk"`` replaced by the space-joined words and speakers. The input
        document and its utterance dict are not modified.
    """
    joined_text = " ".join([entry["word"] for entry in global_words])
    joined_speakers = " ".join([entry["spk"] for entry in global_words])

    utterance = original_json["utterances"][0].copy()
    utterance.update(hyp_text=joined_text, hyp_spk=joined_speakers)

    result = original_json.copy()
    result["utterances"] = [utterance]
    return result


def process_single_json_file(json_file, model, tokenizer, output_folder):
    """Relabel the speakers of one JSON file and write the result to ``output_folder``.

    Steps: load the file, expand its single utterance into word records, split
    them into overlapping chunks at speaker changes, run every chunk through the
    model, write the returned labels back onto the word records, and save the
    rebuilt document under the same file name in ``output_folder``.

    Args:
        json_file: Path of the input JSON file.
        model: Causal LM passed through to ``process_chunk_with_diarizationlm``.
        tokenizer: Tokenizer matching ``model``.
        output_folder: Directory for the output file; created if it does not exist.

    Returns:
        The rebuilt JSON document (a dict).

    Raises:
        ValueError: Propagated from ``parse_word_level`` when the file does not
            contain exactly one utterance or words and speakers do not line up.
        OSError, json.JSONDecodeError: If the file cannot be read or parsed.
    """
    print(f"\nProcessing {os.path.basename(json_file)}")

    # Read as UTF-8 explicitly rather than relying on the platform's locale default.
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Parse words
    words = parse_word_level(data)

    # Adaptive chunking by speaker changes + overlap
    chunks = chunk_adaptive(words, overlap_window=5)

    # Shallow copy: the list is new, but it holds the same word dicts as the
    # chunks, so in-place label updates on chunk entries are visible here.
    global_words = words.copy()
    total_changes = 0

    for idx, chunk in enumerate(chunks):
        print(f"Chunk {idx+1}/{len(chunks)} (size {len(chunk)})")

        output = process_chunk_with_diarizationlm(chunk, model, tokenizer)
        updated_speakers = parse_diarization_output_optimized(output, chunk)

        # Overwrites earlier labels when overlaps occur
        total_changes += apply_chunk_speaker_updates(global_words, chunk, updated_speakers)

    print(f"   Total speaker label changes: {total_changes}")

    # Rebuild JSON
    output_json = rebuild_json(global_words, data)

    # Make the function usable on its own, not only via process_all_json_files.
    os.makedirs(output_folder, exist_ok=True)
    out_path = os.path.join(output_folder, os.path.basename(json_file))
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(output_json, f, indent=2)

    print(f"Saved → {out_path}")
    return output_json


def process_all_json_files(input_folder, output_folder):
    """Load the model once and process every ``*.json`` file in ``input_folder``.

    Files are handled in sorted path order. A failure in one file is reported and
    does not stop the remaining files; the final summary line states how many
    files completed.

    Args:
        input_folder: Directory searched (non-recursively) for ``*.json`` files.
        output_folder: Directory the results are written to; created if missing.
    """
    print("Loading DiarizationLM model...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = LlamaForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto"
    )

    # glob returns files in arbitrary order; sort so runs are comparable.
    json_files = sorted(glob.glob(os.path.join(input_folder, "*.json")))
    print(f"Found {len(json_files)} JSON files.\n")

    os.makedirs(output_folder, exist_ok=True)

    count = 0
    for jf in json_files:
        try:
            result = process_single_json_file(jf, model, tokenizer, output_folder)
        except Exception as exc:
            # One malformed or unreadable file must not abort the whole batch.
            print(f"ERROR: failed to process {os.path.basename(jf)}: {exc!r}")
            continue
        if result is not None:
            count += 1

    print(f"\nCompleted: {count}/{len(json_files)} files processed successfully.")


# Entry point: run the whole batch with the folders defined in the configuration cell.
# Executing this cell loads the model and processes every input file.
if __name__ == "__main__":
    process_all_json_files(INPUT_FOLDER, OUTPUT_FOLDER)


Loading DiarizationLM model...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Found 17 JSON files.


Processing SBC017.json
Adaptive chunking created 45 chunks.
Chunk 1/45 (size 200)
   → Processing 200 words...
Chunk 2/45 (size 53)
   → Processing 53 words...
Chunk 3/45 (size 21)
   → Processing 21 words...
Chunk 4/45 (size 36)
   → Processing 36 words...
Chunk 5/45 (size 22)
   → Processing 22 words...
Chunk 6/45 (size 67)
   → Processing 67 words...
Chunk 7/45 (size 33)
   → Processing 33 words...
Chunk 8/45 (size 7)
   → Processing 7 words...
Chunk 9/45 (size 126)
   → Processing 126 words...
Chunk 10/45 (size 88)
   → Processing 88 words...
Chunk 11/45 (size 60)
   → Processing 60 words...
Skipping meaningless speaker token: <speaker:>
Chunk 12/45 (size 32)
   → Processing 32 words...
Skipping meaningless speaker token: <speaker:>
Chunk 13/45 (size 76)
   → Processing 76 words...
Chunk 14/45 (size 162)
   → Processing 162 words...
Chunk 15/45 (size 70)
   → Processing 70 words...
Chunk 16/45 (size 21)
   → Processing 21 words...
Chunk 17/45 (size 329)
   → 