In [None]:
# Mount Google Drive at /content/drive; later cells read checkpoints and audio from paths under it.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch

# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
torch_device = torch.device(device)
# Show both the plain string and the torch.device form, one per line.
print(device, torch_device, sep="\n")

cuda:0
cuda:0


# Install

- Database

In [None]:
!pip install -q pymongo

- Speech2Text

In [None]:
!pip install -q --upgrade pip
!pip install -q --upgrade transformers datasets[audio] accelerate
# !pip install -q --upgrade transformers accelerate
!pip install -q torch torchvision torchaudio
!pip install -q pyannote.audio
!pip install -q -U openai-whisper

# Import

### Framework

In [None]:
import os
from pathlib import Path
from io import BytesIO
import librosa

import re
import string
import json
from google.colab import output
import time
import pandas as pd
import numpy as np
import unicodedata
import collections
from copy import deepcopy

In [None]:
# Database
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
import gridfs
from bson import ObjectId

In [None]:
from IPython.display import Javascript
from IPython.display import Audio

In [None]:
# Speech2Text
import shutil
import subprocess
from IPython.display import Javascript
from IPython.display import Audio
from base64 import b64decode

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import whisper
from whisper import load_model

# Diarization
from pyannote.audio import Pipeline

In [None]:
# SOAP Gen
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    GenerationConfig,
    TrainingArguments,
    Trainer,
    PegasusForConditionalGeneration,
    PegasusTokenizer,
    PegasusTokenizerFast,
)

from peft import PeftModel, PeftConfig, get_peft_model_state_dict

# Implement

### Utils

In [None]:
def save_to_json(data, filename):
    """
    Serialize ``data`` as JSON (4-space indent) into ``filename`` and
    print a confirmation line with the destination path.
    """
    with open(filename, mode='w') as out_file:
        json.dump(data, out_file, indent=4)
    print(f"Saved to: {filename}")

def load_from_json(filename):
    """
    Read ``filename`` and return its JSON content parsed into Python objects.
    """
    with open(filename, mode='r') as source:
        parsed = json.load(source)
    return parsed

def convert_to_wav(input_path: str, output_path: str) -> str:
    """
    Convert ``input_path`` to a WAV file by invoking the ``ffmpeg`` executable.

    The destination is ``output_path`` with its suffix replaced by ".wav".
    ffmpeg is called with "-y", "-ac 1" and "-ar 16000"; because
    ``check=True`` is used, a non-zero exit status raises
    ``subprocess.CalledProcessError``.

    Returns:
        str: Path of the WAV file that ffmpeg was asked to write.
    """
    # Whatever extension the caller supplied, the result always ends in .wav.
    wav_target = Path(output_path).with_suffix(".wav")
    output_wav_path = str(wav_target)

    command = ["ffmpeg", "-y", "-i", input_path]
    command += ["-ac", "1", "-ar", "16000", output_wav_path]
    subprocess.run(command, check=True)

    return output_wav_path

### Database

In [None]:
# === Connect to DB ===
def connect_to_db(uri: str, db_name: str):
    """
    Open a MongoDB client for ``uri`` and select the database ``db_name``.

    A 'ping' command is sent as a connectivity check. If it fails, the
    exception is printed and the handles are still returned.

    Returns:
        tuple: (database handle, gridfs.GridFS bound to that database)
    """
    mongo = MongoClient(uri)
    database = mongo[db_name]
    try:
        mongo.admin.command('ping')
    except Exception as err:
        # Best-effort check: report the problem but still hand back the handles.
        print(err)
    else:
        print("Pinged your deployment. You successfully connected to MongoDB!")
    bucket = gridfs.GridFS(database)
    return database, bucket

# === Insert Audio File ===
def upload_audio(fs, file_path: str):
    """
    Store the file at ``file_path`` through ``fs.put`` under its base name.

    Returns:
        Whatever ``fs.put`` returns (the ID of the stored file).
    """
    stored_name = os.path.basename(file_path)
    with open(file_path, "rb") as audio_file:
        return fs.put(audio_file, filename=stored_name)

# === Insert Full Conversation Record ===
def insert_conversation(db, name, audio_id: ObjectId, transcript: list, word_transcript: list, summary: dict):
    """
    Insert one document into the "conversation" collection of ``db``.

    The document carries the fields name, audio_id, transcript,
    word_transcript and summary, taken from the arguments of the same name.

    Returns:
        The ``inserted_id`` reported by ``insert_one``.
    """
    record = dict(
        name=name,
        audio_id=audio_id,
        transcript=transcript,
        word_transcript=word_transcript,
        summary=summary,
    )
    collection = db["conversation"]
    outcome = collection.insert_one(record)
    return outcome.inserted_id

# === Retrieve a Record ===
def get_conversation(db, conversation_id: str):
    """
    Look up a document in the "conversation" collection by ``_id``.

    ``conversation_id`` is converted to an ``ObjectId`` before querying;
    the result of ``find_one`` is returned unchanged.
    """
    collection = db["conversation"]
    query = {"_id": ObjectId(conversation_id)}
    return collection.find_one(query)

# === Get Audio by ID ===
def get_audio(fs, audio_id: str):
    """
    Fetch the file stored under ``audio_id`` from ``fs`` and return the
    result of reading it in full.
    """
    stored_file = fs.get(ObjectId(audio_id))
    return stored_file.read()

# === Download Audio by ID ===
def download_audio(fs, audio_id: str, save_path: str):
    """
    Read the file stored under ``audio_id`` from ``fs`` and write its
    content to ``save_path`` in binary mode.
    """
    stored_file = fs.get(ObjectId(audio_id))
    audio_bytes = stored_file.read()
    with open(save_path, "wb") as out_file:
        out_file.write(audio_bytes)

### Dialogue2Text

In [None]:
def transcribe_with_whisper(audio_path, model, language="en"):
    """
    Transcribe ``audio_path`` with ``model`` and flatten the result to words.

    ``model.transcribe`` is called with ``word_timestamps=True`` and the
    given ``language``. Every word of every returned segment becomes one
    entry, in order.

    Returns:
        list of dict: {"word", "start", "end"} with the word stripped of
        surrounding whitespace and both times rounded to 2 decimals.
    """
    transcription = model.transcribe(
        audio_path,
        word_timestamps=True,
        language=language,  # "en", "vi", "ja"
    )
    # Each segment carries its own "words" list; flatten them into one sequence.
    return [
        {
            "word": entry["word"].strip(),
            "start": round(entry["start"], 2),
            "end": round(entry["end"], 2),
        }
        for segment in transcription["segments"]
        for entry in segment["words"]
    ]

def diarize_with_pyannote(audio_path, pipeline, device="cuda"):
    """
    Run the diarization ``pipeline`` on ``audio_path``.

    The pipeline is first moved to ``torch.device(device)`` and then called
    with ``{"audio": audio_path}``.

    Returns:
        list of dict: one {"start", "end", "speaker"} entry per track yielded
        by ``itertracks(yield_label=True)``, with times rounded to 2 decimals.
    """
    target_device = torch.device(device)
    pipeline.to(target_device)
    annotation = pipeline({"audio": audio_path})

    segments = []
    for turn, _, speaker in annotation.itertracks(yield_label=True):
        segments.append({
            "start": round(turn.start, 2),
            "end": round(turn.end, 2),
            "speaker": speaker,
        })
    return segments

def assign_speakers(tokens, segments):
    """
    Attach a speaker label to every token.

    For each token, find the diarization segment whose [start, end] range
    contains the token's start time. If no segment covers it (including the
    case where ``segments`` is empty), the token is labelled 'UNK'.

    The segment pointer only moves forward, so ``tokens`` are expected in
    chronological order; ``segments`` are sorted by start time here.

    Parameters:
        tokens (list of dict): Items with at least a "start" key.
        segments (list of dict): Items with "start", "end" and "speaker".

    Returns:
        list of dict: Copies of the tokens with an added "speaker" key.
    """
    ordered = sorted(segments, key=lambda seg: seg["start"])
    if not ordered:
        # Nothing to match against: indexing ordered[0] would raise IndexError.
        return [{**token, "speaker": "UNK"} for token in tokens]

    diarized_tokens = []
    idx = 0
    last = len(ordered) - 1
    for token in tokens:
        token_start = token["start"]
        # Skip segments that ended before this token begins.
        while idx < last and ordered[idx]["end"] < token_start:
            idx += 1
        seg = ordered[idx]
        covered = seg["start"] <= token_start <= seg["end"]
        diarized_tokens.append({**token, "speaker": seg["speaker"] if covered else "UNK"})
    return diarized_tokens

def build_diarized_transcript(diarized_tokens):
    """
    Collapse consecutive tokens spoken by the same speaker into utterances.

    Returns:
        list of dict: {"speaker", "start", "end", "text"} where start comes
        from the first token of the run, end from the last, and text is the
        run's words joined by single spaces. Empty input gives [].
    """
    utterances = []
    for tok in diarized_tokens:
        current = utterances[-1] if utterances else None
        if current is not None and current["speaker"] == tok["speaker"]:
            # Same speaker keeps talking: extend the open utterance.
            current["end"] = tok["end"]
            current["text"] = current["text"] + " " + tok["word"]
            continue
        # First token or speaker change: open a new utterance.
        utterances.append({
            "speaker": tok["speaker"],
            "start": tok["start"],
            "end": tok["end"],
            "text": tok["word"],
        })
    return utterances

def merge_unk_into_next(utterances):
    """
    Fold every 'UNK' utterance into the utterance that follows it.

    The UNK text is prepended to the following utterance and that
    utterance's start time is moved back to the UNK start. Consecutive UNK
    utterances chain into the next non-UNK one. A trailing UNK utterance
    (nothing follows it) is kept as is.

    The input list and its dicts are left untouched; the result is built
    from shallow copies, so calling this again on the same input gives the
    same output.

    Parameters:
        utterances (list of dict): Items with 'speaker', 'start', 'end', 'text'.

    Returns:
        list of dict: New utterance dicts with UNK entries merged forward.
    """
    merged = []
    pending = None  # UNK utterance waiting to be folded into the next one
    last_index = len(utterances) - 1
    for position, original in enumerate(utterances):
        utt = dict(original)  # copy so the caller's data is never modified
        if pending is not None:
            utt["text"] = pending["text"] + " " + utt["text"]
            utt["start"] = pending["start"]
            pending = None
        if utt["speaker"] == "UNK" and position < last_index:
            pending = utt
        else:
            merged.append(utt)
    return merged

def add_utterance_ids(utterances, prefix="U"):
    """
    Number the utterances in place, starting at 1.

    Parameters:
        utterances (list of dict): Utterance dictionaries; each one receives
            an 'utterance_id' key.
        prefix (str): Text placed before the running number, default 'U'.

    Returns:
        list of dict: The same ``utterances`` object that was passed in.
    """
    counter = 0
    for utt in utterances:
        counter += 1
        utt["utterance_id"] = f"{prefix}{counter}"
    return utterances

def process_audio(audio_path, model_whisper, pipeline_diarization, language="en", device=None):
    """
    Turn an audio file into word-level tokens and a speaker-attributed transcript.

    Steps: Whisper transcription with word timestamps, pyannote diarization,
    speaker assignment per word, grouping into utterances, merging of 'UNK'
    utterances into the following one, and utterance numbering ("U1", ...).

    Parameters:
        audio_path (str): Path of the audio file.
        model_whisper: Loaded Whisper model.
        pipeline_diarization: Loaded pyannote diarization pipeline.
        language (str): Language code passed to Whisper, default "en".
        device (str | None): Device for the diarization pipeline. When None,
            "cuda" is used if available, otherwise "cpu".

    Returns:
        tuple: (tokens, clean_utterances)
    """
    if device is None:
        # Avoid failing on CPU-only runtimes, where "cuda" cannot be used.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Whisper
    tokens = transcribe_with_whisper(audio_path, model_whisper, language)

    # Pyannote
    segments = diarize_with_pyannote(audio_path, pipeline_diarization, device=device)

    # Process
    diarized_tokens = assign_speakers(tokens, segments)
    raw_utterances = build_diarized_transcript(diarized_tokens)
    clean_utterances = merge_unk_into_next(raw_utterances)
    clean_utterances = add_utterance_ids(clean_utterances, "U")

    return tokens, clean_utterances

### Role Classifier

In [None]:
# Doctor indicators
# Lowercase phrases; role_classification counts how many of them occur as
# substrings in a speaker's combined, lower-cased text (each phrase counts once).
# NOTE(review): "i believe" is also listed in patient_keywords, so it raises both scores.
doctor_keywords = [
    # Questions and inquiries
    "what brings you", "how are you feeling", "how long have you", "when did this start",
    "can you describe", "tell me about", "any other symptoms", "have you experienced",
    "do you have any", "are you taking", "have you tried", "how would you rate",
    "on a scale of", "does it hurt when", "can you point to", "how often do you",

    # Medical examination language
    "let me examine", "i'm going to", "let me check", "i need to", "let me listen",
    "take a deep breath", "say ah", "follow my finger", "look up", "look down",
    "turn your head", "can you lift", "does this hurt", "feel any pressure",

    # Medical recommendations and instructions
    "i recommend", "you should", "i suggest", "my advice", "you need to",
    "take this", "apply this", "rest for", "avoid", "come back in",
    "follow up", "schedule", "return if", "call if", "monitor",

    # Medical terminology and diagnosis
    "diagnosis", "condition", "infection", "inflammation", "prescription",
    "medication", "treatment", "therapy", "procedure", "test results",
    "blood work", "x-ray", "scan", "allergy", "dosage", "side effects",
    "medical history", "family history", "chronic", "acute", "symptoms indicate",

    # Professional phrases
    "in my opinion", "based on", "it appears", "it looks like", "i believe",
    "we need to rule out", "differential diagnosis", "likely cause", "i suspect"
]

# Comprehensive patient indicators
# Lowercase phrases; role_classification counts how many of them occur as
# substrings in a speaker's combined, lower-cased text (each phrase counts once).
# Because matching is by substring, short entries also match inside longer
# words (e.g. "pain" is found in "painful").
patient_keywords = [
    # Personal symptoms and feelings
    "i have", "i feel", "i'm experiencing", "i've been having", "i get",
    "i notice", "i can't", "i'm unable to", "it hurts", "it's painful",
    "i'm worried", "i'm concerned", "i think", "i believe", "i'm afraid",

    # Pain and discomfort descriptions
    "hurts", "pain", "painful", "ache", "aching", "sore", "tender",
    "burning", "stinging", "throbbing", "sharp", "dull", "cramping",
    "tight", "pressure", "uncomfortable", "bothering me", "killing me",

    # Symptom descriptions
    "sick", "nauseous", "dizzy", "tired", "weak", "fever", "chills",
    "headache", "stomach ache", "runny nose", "cough", "congested",
    "swollen", "rash", "itchy", "blurry", "ringing", "numbness",

    # Personal references and possessives
    "my head", "my back", "my stomach", "my chest", "my throat", "my arm",
    "my leg", "my eye", "my ear", "my skin", "my heart", "my breathing",

    # Timeline and frequency from patient perspective
    "started yesterday", "been going on", "happens when", "gets worse",
    "feels better", "comes and goes", "all the time", "at night", "in the morning",
    "after eating", "before bed", "during", "since", "for days", "for weeks",

    # Lifestyle and personal context
    "i work", "i sleep", "i eat", "i drink", "i smoke", "i exercise",
    "i live", "i usually", "normally i", "my job", "my family", "my wife",
    "my husband", "my kids", "at home", "at work"
]

def role_classification(dialogue, doctor_keywords, patient_keywords):
    """
    Rule-based guess of which of two speakers is the doctor.

    For each speaker, all of their text is joined and lower-cased, then
    every keyword that occurs as a substring adds one point to the
    corresponding score. The speaker with the higher doctor score becomes
    "Doctor" and the other "Patient". Ties are broken deterministically:
    first by the lower patient score, then by who speaks first.

    Parameters:
        dialogue (list of dict): Segments with "speaker" and "text".
        doctor_keywords (list of str): Lowercase phrases typical of a doctor.
        patient_keywords (list of str): Lowercase phrases typical of a patient.

    Returns:
        dict: speaker label -> "Doctor" / "Patient", or every speaker mapped
        to "Unknown" when the dialogue does not have exactly two speakers.
    """
    # dict.fromkeys keeps first-appearance order; a set would make tie results
    # depend on string hashing and differ between runs.
    speakers = list(dict.fromkeys(segment["speaker"] for segment in dialogue))

    if len(speakers) != 2:
        # If not exactly 2 speakers, return default mapping
        return {speaker: "Unknown" for speaker in speakers}

    speaker_analysis = {}
    for speaker in speakers:
        combined_text = " ".join(
            seg["text"] for seg in dialogue if seg["speaker"] == speaker
        ).lower()
        speaker_analysis[speaker] = {
            "doctor_score": sum(1 for keyword in doctor_keywords if keyword in combined_text),
            "patient_score": sum(1 for keyword in patient_keywords if keyword in combined_text),
        }

    # Highest doctor score first; on a tie the less patient-like speaker wins.
    # sorted() is stable, so a full tie keeps first-appearance order.
    ranked = sorted(
        speakers,
        key=lambda spk: (
            -speaker_analysis[spk]["doctor_score"],
            speaker_analysis[spk]["patient_score"],
        ),
    )

    return {ranked[0]: "Doctor", ranked[1]: "Patient"}

def replace_speaker_labels(transcript, speaker_labels):
    """
    Return a copy of the transcript with speaker IDs swapped for roles.

    Args:
        transcript (List[Dict[str, Any]]): Dialogue segments, each with a "speaker" key.
        speaker_labels (Dict[str, str]): Original speaker ID -> role
            (e.g., {"SPEAKER_00": "Doctor", "SPEAKER_01": "Patient"}).

    Returns:
        List[Dict[str, Any]]: Shallow copies of the segments with "speaker"
        replaced where a mapping exists. Segments whose speaker is missing
        from ``speaker_labels`` keep their label and a warning is printed.
    """
    relabelled = []
    for segment in transcript:
        entry = segment.copy()  # the input segments stay untouched
        speaker = segment["speaker"]
        if speaker not in speaker_labels:
            print(f"Warning: Speaker '{speaker}' not found in speaker_labels. Keeping original label.")
        else:
            entry["speaker"] = speaker_labels[speaker]
        relabelled.append(entry)
    return relabelled

def classify_speakers(transcript, doctor_keywords, patient_keywords):
    """
    Classify the speakers of ``transcript`` with role_classification and
    return the transcript relabelled through replace_speaker_labels.
    """
    return replace_speaker_labels(
        transcript,
        role_classification(transcript, doctor_keywords, patient_keywords),
    )

### SOAP Gen

In [None]:
def transcript2string(utterances: list[dict]) -> str:
    """
    Render utterances as text, one "<speaker>: <text>" line per utterance.

    Entries lacking a 'speaker' or 'text' key are skipped; the text is
    stripped of surrounding whitespace. Lines are joined with newlines.
    """
    lines = []
    for entry in utterances:
        if 'speaker' not in entry or 'text' not in entry:
            continue
        lines.append(f"{entry['speaker']}: {entry['text'].strip()}")
    return "\n".join(lines)

In [None]:
def process_soap_traceability(dialogue, tok, tok_fast, ft_model):
    """
    Generate a SOAP summary for a dialogue and trace every summary sentence
    back to the dialogue utterance it attends to most.

    Parameters:
        dialogue (str): Transcript text whose turns are introduced by
            "Doctor:" / "Patient:" labels.
        tok: Tokenizer used to encode the dialogue for the model.
        tok_fast: Fast tokenizer used to decode the summary and to obtain
            character offsets of the summary tokens.
        ft_model: Seq2seq model that supports ``generate`` and returns
            ``cross_attentions`` when called with ``output_attentions=True``.

    Returns:
        dict: {"utterances": {utterance_id: text}, "S": [...], "O": [...],
        "A": [...], "P": [...]}. Each section list holds
        {"sentence_idx", "sentence_text", "utterance_id"} entries;
        "utterance_id" is "Unknown" when no utterance could be attributed.

    Notes:
        Reads the module-level ``device`` and ``gen_cfg`` objects.
        The dialogue is truncated to 512 encoder tokens; utterances that do
        not appear in the truncated input receive no tokens and therefore
        can never be selected as a source.
    """
    # ------------------------------------------------------------------
    # 1) SPLIT dialogue into utterances + IDs ---------------------------
    parts = re.split(r'(Doctor:|Patient:)', dialogue)[1:]
    utterances = [parts[i] + parts[i + 1] for i in range(0, len(parts), 2)]
    utterance_ids = [f"U{i+1}" for i in range(len(utterances))]
    utt_dict = {uid: utt.strip() for uid, utt in zip(utterance_ids, utterances)}
    if not utterances:
        print("[WARN] No 'Doctor:'/'Patient:' turns found in dialogue; "
              "summary sentences will be linked to 'Unknown'.")

    # ------------------------------------------------------------------
    # 2) MAP encoder tokens -> utterance --------------------------------
    enc_inputs = tok(dialogue, return_tensors="pt", max_length=512, truncation=True).to(device)
    enc_tokens = tok.convert_ids_to_tokens(enc_inputs["input_ids"][0])

    # Locate the first token of each utterance, scanning left to right.
    # Utterances that cannot be located (e.g. cut off by truncation) are
    # skipped so that they do not steal tokens from the last visible one.
    located = []          # [(encoder position, utterance index)], ascending
    search_from = 0
    for utt_idx, utt in enumerate(utterances):
        pieces = tok.tokenize(utt.strip())
        if not pieces:
            continue
        try:
            pos = enc_tokens.index(pieces[0], search_from)
        except ValueError:
            continue
        located.append((pos, utt_idx))
        search_from = pos + 1

    # Each encoder token belongs to the latest utterance starting at or
    # before it; tokens before the first located start map to "Unknown".
    token_to_utt = []
    cursor = -1
    for i in range(len(enc_tokens)):
        while cursor + 1 < len(located) and located[cursor + 1][0] <= i:
            cursor += 1
        token_to_utt.append(utterance_ids[located[cursor][1]] if cursor >= 0 else "Unknown")

    # ------------------------------------------------------------------
    # 3) GENERATE SOAP summary -----------------------------------------
    gen_ids = ft_model.generate(**enc_inputs, generation_config=gen_cfg)
    summary_text = tok_fast.decode(gen_ids[0],
                                   skip_special_tokens=True,
                                   clean_up_tokenization_spaces=True)

    print(f"[INFO] Generated summary length: {len(summary_text)} characters, {len(gen_ids[0])} tokens")

    # Forced decoding over the generated ids to obtain cross-attentions.
    with torch.no_grad():
        outs = ft_model(**enc_inputs, labels=gen_ids, output_attentions=True)

    # (layers, b, h, dec_len, enc_len) -> average over layers and heads
    all_xattn = torch.stack(outs.cross_attentions, dim=0)
    avg_xattn = all_xattn.mean(dim=(0, 2))[0]                    # (dec_len, enc_len)

    # Re-tokenize the decoded summary without specials to get char offsets.
    enc_summary = tok_fast(summary_text,
                           add_special_tokens=False,
                           return_offsets_mapping=True)
    dec_tokens = tok_fast.convert_ids_to_tokens(enc_summary["input_ids"])
    offsets = enc_summary["offset_mapping"]

    # Attention rows are paired with dec_tokens by position.
    # NOTE(review): gen_ids still contains special tokens while dec_tokens
    # does not, so rows and tokens may be shifted by one — confirm alignment.
    avg_xattn = avg_xattn[:len(dec_tokens), :]

    if avg_xattn.shape[0] != len(dec_tokens):
        print("Warning: shape mismatch after filtering special tokens:",
              avg_xattn.shape[0], "rows vs", len(dec_tokens), "decoder tokens")

    # ------------------------------------------------------------------
    # 4) PARSE SOAP into sections + sentences --------------------------
    soap_sec_pat = re.compile(r'\b([SOAP]):')
    matches = list(soap_sec_pat.finditer(summary_text))
    sent_entries = []                       # [{section; sent_idx; start; end; text}]

    for idx, m in enumerate(matches):
        sec = m.group(1)
        start_txt = m.end()
        end_txt = matches[idx + 1].start() if idx + 1 < len(matches) else len(summary_text)
        section_text = summary_text[start_txt:end_txt].strip()

        # Split the section into sentences and record their char spans.
        cursor_txt = start_txt
        local_idx = 0
        for s in re.split(r'(?<=[.!?])\s+', section_text):
            s_clean = s.strip()
            if not s_clean:
                continue
            s_start = summary_text.find(s_clean, cursor_txt, end_txt)
            s_end = s_start + len(s_clean)
            sent_entries.append({
                "section":   sec,
                "sent_idx":  local_idx,
                "start":     s_start,
                "end":       s_end,
                "text":      s_clean,
            })
            cursor_txt = s_end + 1
            local_idx += 1

    # ------------------------------------------------------------------
    # 5) TOKEN FILTER ---------------------------------------------------
    SPECIALS = set(tok_fast.all_special_tokens)
    SECTION_TOKENS = {"▁S", "▁O", "▁A", "▁P", "S", "O", "A", "P"}
    STOP_WORDS = {
        "the", "a", "an", "and", "or", "but", "if", "with", "of", "for", "to",
        "in", "on", "at", "by", "is", "are", "was", "were", "be", "been", "am",
        "it", "this", "that", "these", "those", "as", "has", "have", "had",
    }

    def all_punct(text: str) -> bool:
        """True if every character in text is Unicode punctuation."""
        return all(unicodedata.category(ch).startswith("P") for ch in text)

    def is_skip(piece):
        """True for tokens that carry no content: punctuation, stop-words, specials, section letters."""
        t = piece.replace('▁', '').lower()
        return (
            t in string.punctuation        # ASCII punctuation (also true for "")
            or t in STOP_WORDS
            or piece in SPECIALS           # <pad>, </s>, ...
            or piece in SECTION_TOKENS
            or t in {"", "<pad>", "</s>"}
            or all_punct(t)
        )

    top_k_token_att = 5  # number of strongest encoder tokens that vote per decoder token

    # ------------------------------------------------------------------
    # 6) TOKEN-LEVEL vote (top-k enc tokens) ---------------------------
    token_vote_utt_dict = {}      # dec_pos -> winner utterance (e.g. 3: "U1")

    for dpos, scores in enumerate(avg_xattn):            # dpos indexes dec_tokens
        if is_skip(dec_tokens[dpos]):
            continue

        # Walk encoder tokens from strongest to weakest attention and keep
        # the first k that are not skipped.
        keep = []
        for enc_idx in torch.argsort(scores, descending=True).tolist():
            if not is_skip(enc_tokens[enc_idx]):
                keep.append(enc_idx)
                if len(keep) == top_k_token_att:
                    break
        if not keep:                                     # every encoder token was skipped
            continue

        utts = [token_to_utt[i] for i in keep]
        token_vote_utt_dict[dpos] = collections.Counter(utts).most_common(1)[0][0]

    # ------------------------------------------------------------------
    # 7) SENTENCE-LEVEL vote -------------------------------------------
    # Map each decoder token to the sentence containing its start offset.
    tok2sent = []
    cur_sent = 0
    for start_char, _ in offsets:
        while cur_sent + 1 < len(sent_entries) and start_char >= sent_entries[cur_sent + 1]["start"]:
            cur_sent += 1
        tok2sent.append(cur_sent)

    # Majority vote of the token winners inside each sentence.
    for global_idx, s in enumerate(sent_entries):
        winners = [token_vote_utt_dict[d]
                   for d, sid in enumerate(tok2sent)
                   if sid == global_idx and d in token_vote_utt_dict]
        s["utterance_id"] = (
            collections.Counter(winners).most_common(1)[0][0] if winners else "Unknown"
        )

    # ------------------------------------------------------------------
    # 8) BUILD FINAL JSON ----------------------------------------------
    final = {"utterances": utt_dict,
             "S": [], "O": [], "A": [], "P": []}

    for sent in sent_entries:
        final[sent["section"]].append({
            "sentence_idx":  str(sent["sent_idx"]),
            "sentence_text": sent["text"],
            "utterance_id":  sent["utterance_id"],
        })

    return final

# Initialize

In [None]:
from google.colab import userdata

### Database

In [None]:
# The MongoDB URI comes from Colab userdata (key 'URI_MONGODB') instead of being hard-coded.
uri = userdata.get('URI_MONGODB')
# `db` is the "thesis" database handle and `fs` its GridFS wrapper; both are used by later cells.
db, fs = connect_to_db(uri, "thesis")

Pinged your deployment. You successfully connected to MongoDB!


### Dialogue2Text

In [None]:
# Whisper
# The checkpoint name lives in its own constant so that `model_whisper`
# only ever refers to the loaded model, never to a string.
WHISPER_MODEL_NAME = "large-v3"
model_whisper = load_model(WHISPER_MODEL_NAME, device=device)

In [None]:
# Pyannote
# The Hugging Face token is read from Colab userdata (key 'HF_TOKEN_2') instead of being hard-coded.
HF_TOKEN = userdata.get('HF_TOKEN_2')
pipeline_name = "pyannote/speaker-diarization"
# NOTE(review): the recorded output of this cell warns that the checkpoint was trained with
# pyannote.audio 0.0.1 / torch 1.10 while 3.3.2 / 2.6 are installed — confirm this pipeline
# version is the intended one.
pipeline_diarization = Pipeline.from_pretrained(pipeline_name, use_auth_token=HF_TOKEN).to(torch_device)

DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for _speechbrain_save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for _speechbrain_load
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for load
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for _save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for _recover
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.5.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../root/.cache/torch/pyannote/models--pyannote--segmentation/snapshots/c4c8ceafcbb3a7a280c2d357aee9fbc9b0be7f9b/pytorch_model.bin`
INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/hyperparams.yaml'


Model was trained with pyannote.audio 0.0.1, yours is 3.3.2. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.6.0+cu124. Bad things might happen unless you revert torch to 1.x.


DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for _save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for _load
DEBUG:speechbrain.utils.checkpoints:Registered parameter transfer hook for _load
  wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for load_if_possible
DEBUG:speechbrain.utils.parameter_transfer:Collecting files (or symlinks) for pretraining in /root/.cache/torch/pyannote/speechbrain.
INFO:speechbrain.utils.fetching:Fetch embedding_model.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["embedding_model"] = /root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt
INFO:speechbrain.utils.fetching:Fetch mean_var_norm_emb.ckpt: Using symlink found at '/root

### SOAP Gen

In [None]:
# Locations of the fine-tuned summarization model on the mounted Drive.
# Lines marked ADJUST select which model / run is loaded.
fine_tune_path = '/content/drive/MyDrive/ClinicalNotesGen/Summarization/3_Fine_Tune_LLM'
model_name = 'pegasus' # ADJUST
sub_model_name = 'pegasus_xsum' # ADJUST
checkpoints_dir = f"{fine_tune_path}/{model_name}/{sub_model_name}/lora_1" # ADJUST
# Sub-directories of the selected run; only final_checkpoints_path is read by the visible cells.
checkpoints_path = f"{checkpoints_dir}/checkpoints"
final_checkpoints_path = f"{checkpoints_dir}/final_checkpoints"
summary_path = f"{checkpoints_dir}/summary"

print(f"final_checkpoints_path: {final_checkpoints_path}")

final_checkpoints_path: /content/drive/MyDrive/ClinicalNotesGen/Summarization/3_Fine_Tune_LLM/pegasus/pegasus_xsum/lora_1/final_checkpoints


In [None]:
# PeftConfig
# Adapter configuration saved with the fine-tuned checkpoint; its
# base_model_name_or_path tells which base model to load below.
perf_config = PeftConfig.from_pretrained(final_checkpoints_path)

# Tokenizer
# `tok` encodes the dialogue; `tok_fast` is used where character offsets of
# the summary tokens are needed (return_offsets_mapping).
tok = AutoTokenizer.from_pretrained(final_checkpoints_path)
tok_fast   = PegasusTokenizerFast.from_pretrained(final_checkpoints_path)

# FT model
# Load the base seq2seq model first, then wrap it with the adapter weights
# from final_checkpoints_path.
ft_base = AutoModelForSeq2SeqLM.from_pretrained(perf_config.base_model_name_or_path, return_dict=True, device_map='auto')
ft_model = PeftModel.from_pretrained(ft_base, final_checkpoints_path)

config.json:   0%|          | 0.00/1.39k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.28G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.28G [00:00<?, ?B/s]

Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-xsum and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


generation_config.json:   0%|          | 0.00/259 [00:00<?, ?B/s]

In [None]:
# NOTE(review): MAX_SOURCE_LEN is not referenced by the visible code;
# process_soap_traceability hard-codes max_length=512 — confirm they should stay in sync.
MAX_SOURCE_LEN = 512
MAX_TARGET_LEN = 428

# Work on a copy so the model's own generation_config is left unchanged.
gen_cfg = deepcopy(ft_model.generation_config)
gen_cfg.max_new_tokens = MAX_TARGET_LEN
# Single beam without sampling, i.e. deterministic greedy decoding.
gen_cfg.num_beams      = 1
gen_cfg.do_sample      = False
gen_cfg.early_stopping = False

# Display the resulting configuration as the cell output.
gen_cfg

GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "forced_eos_token_id": 1,
  "length_penalty": 0.6,
  "max_length": 64,
  "max_new_tokens": 428,
  "pad_token_id": 0
}

# Main

- Input audio

In [None]:
def final_pipeline(name, fs, model_whisper, pipeline_diarization, doctor_keywords, patient_keywords,
                   audio_dir="/content/drive/MyDrive/ClinicalNotesGen/Data/audios/en/wav"):
    """Run the whole audio -> transcript -> summary -> database pipeline for one recording.

    All compute stages (transcription/diarization, role labelling, summary
    generation) run before anything is written to storage, so a failure in one
    of them no longer leaves an orphaned audio file behind. If the conversation
    insert fails after the audio has been uploaded, the upload is deleted again
    before the error is re-raised.

    Relies on the notebook-level globals ``tok``, ``tok_fast``, ``ft_model`` and ``db``.

    Args:
        name: Recording name; the audio file is ``{audio_dir}/{name}.wav``.
        fs: Storage handle passed to ``upload_audio``. It must also provide
            ``delete(file_id)`` (both ``gridfs.GridFS`` and ``GridFSBucket`` do).
        model_whisper: Speech-to-text model forwarded to ``process_audio``.
        pipeline_diarization: Diarization pipeline forwarded to ``process_audio``.
        doctor_keywords: Keywords forwarded to ``classify_speakers``.
        patient_keywords: Keywords forwarded to ``classify_speakers``.
        audio_dir: Directory containing the ``.wav`` files. Defaults to the
            Drive folder that was previously hard-coded.

    Returns:
        The id returned by ``insert_conversation`` for the new record.

    Raises:
        Whatever the individual stages raise; nothing is swallowed.
    """
    audio_path = f"{audio_dir}/{name}.wav"

    # Compute first: none of these stages needs the stored audio id.
    transcript_word, transcript = process_audio(audio_path, model_whisper, pipeline_diarization)
    transcript_labelled = classify_speakers(transcript, doctor_keywords, patient_keywords)
    dialogue = transcript2string(transcript_labelled)
    summary = process_soap_traceability(dialogue, tok, tok_fast, ft_model)

    # Persist only once there is a complete result to store.
    audio_id = upload_audio(fs, audio_path)
    print("audio_id:", audio_id)
    try:
        conversation_id = insert_conversation(db, name, audio_id, transcript, transcript_word, summary)
    except Exception:
        # Roll back the upload so failed runs do not keep consuming storage quota.
        fs.delete(audio_id)
        raise
    print("Inserted conversation ID:", conversation_id)
    return conversation_id

In [None]:
# End-to-end run for the "fever_stomach" recording.
name = 'fever_stomach'
final_pipeline(
    name=name,
    fs=fs,
    model_whisper=model_whisper,
    pipeline_diarization=pipeline_diarization,
    doctor_keywords=doctor_keywords,
    patient_keywords=patient_keywords,
)

audio_id: 685a61637772f10be9ac8dd3


INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/hyperparams.yaml'
INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
DEBUG:speechbrain.utils.parameter_transfer:Collecting files (or symlinks) for pretraining in /root/.cache/torch/pyannote/speechbrain.
INFO:speechbrain.utils.fetching:Fetch embedding_model.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["embedding_model"] = /root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt
INFO:speechbrain.utils.fetching:Fetch mean_var_norm_emb.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/mean_var_norm_emb.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["mean_var_norm_emb"] = /root/.cache/torch/pyannote/speechbrain/mean_var_norm_

[INFO] Generated summary length: 792 characters, 164 tokens
Inserted conversation ID: 685a61777772f10be9ac8ddd


In [None]:
# End-to-end run for the "encounter_fever" recording.
name = 'encounter_fever'
final_pipeline(
    name=name,
    fs=fs,
    model_whisper=model_whisper,
    pipeline_diarization=pipeline_diarization,
    doctor_keywords=doctor_keywords,
    patient_keywords=patient_keywords,
)

audio_id: 685a60b07772f10be9ac8c82


INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/hyperparams.yaml'
INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
DEBUG:speechbrain.utils.parameter_transfer:Collecting files (or symlinks) for pretraining in /root/.cache/torch/pyannote/speechbrain.
INFO:speechbrain.utils.fetching:Fetch embedding_model.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["embedding_model"] = /root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt
INFO:speechbrain.utils.fetching:Fetch mean_var_norm_emb.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/mean_var_norm_emb.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["mean_var_norm_emb"] = /root/.cache/torch/pyannote/speechbrain/mean_var_norm_

[INFO] Generated summary length: 719 characters, 153 tokens
Inserted conversation ID: 685a60bf7772f10be9ac8c8a


In [None]:
# End-to-end run for the "abdominal_pain_history" recording.
name = 'abdominal_pain_history'
final_pipeline(
    name=name,
    fs=fs,
    model_whisper=model_whisper,
    pipeline_diarization=pipeline_diarization,
    doctor_keywords=doctor_keywords,
    patient_keywords=patient_keywords,
)

audio_id: 685a61777772f10be9ac8dde


INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/hyperparams.yaml'
INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
DEBUG:speechbrain.utils.parameter_transfer:Collecting files (or symlinks) for pretraining in /root/.cache/torch/pyannote/speechbrain.
INFO:speechbrain.utils.fetching:Fetch embedding_model.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["embedding_model"] = /root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt
INFO:speechbrain.utils.fetching:Fetch mean_var_norm_emb.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/mean_var_norm_emb.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["mean_var_norm_emb"] = /root/.cache/torch/pyannote/speechbrain/mean_var_norm_

[INFO] Generated summary length: 1055 characters, 219 tokens
Inserted conversation ID: 685a62247772f10be9ac8f4b


In [None]:
# End-to-end run for the "sexual_health_history" recording.
# The recorded output of this cell ends in a MongoDB Atlas space-quota OperationFailure
# raised after the summary had been generated; free storage before re-running.
name = 'sexual_health_history'
final_pipeline(
    name=name,
    fs=fs,
    model_whisper=model_whisper,
    pipeline_diarization=pipeline_diarization,
    doctor_keywords=doctor_keywords,
    patient_keywords=patient_keywords,
)

audio_id: 685a62247772f10be9ac8f4c


INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/hyperparams.yaml'
INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
DEBUG:speechbrain.utils.parameter_transfer:Collecting files (or symlinks) for pretraining in /root/.cache/torch/pyannote/speechbrain.
INFO:speechbrain.utils.fetching:Fetch embedding_model.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["embedding_model"] = /root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt
INFO:speechbrain.utils.fetching:Fetch mean_var_norm_emb.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/mean_var_norm_emb.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["mean_var_norm_emb"] = /root/.cache/torch/pyannote/speechbrain/mean_var_norm_

[INFO] Generated summary length: 1388 characters, 252 tokens


OperationFailure: you are over your space quota, using 525 MB of 512 MB, full error: {'ok': 0, 'errmsg': 'you are over your space quota, using 525 MB of 512 MB', 'code': 8000, 'codeName': 'AtlasError'}

In [None]:
# End-to-end run for the "encounter_chest_pain" recording.
name = 'encounter_chest_pain'
final_pipeline(
    name=name,
    fs=fs,
    model_whisper=model_whisper,
    pipeline_diarization=pipeline_diarization,
    doctor_keywords=doctor_keywords,
    patient_keywords=patient_keywords,
)

audio_id: 685a5fa27772f10be9ac8c56


INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/hyperparams.yaml'
INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
DEBUG:speechbrain.utils.parameter_transfer:Collecting files (or symlinks) for pretraining in /root/.cache/torch/pyannote/speechbrain.
INFO:speechbrain.utils.fetching:Fetch embedding_model.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["embedding_model"] = /root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt
INFO:speechbrain.utils.fetching:Fetch mean_var_norm_emb.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/mean_var_norm_emb.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["mean_var_norm_emb"] = /root/.cache/torch/pyannote/speechbrain/mean_var_norm_

[INFO] Generated summary length: 1137 characters, 223 tokens
Inserted conversation ID: 685a60217772f10be9ac8c81


In [None]:
# End-to-end run for the "type_2_diabetes" recording.
name = 'type_2_diabetes'
final_pipeline(
    name=name,
    fs=fs,
    model_whisper=model_whisper,
    pipeline_diarization=pipeline_diarization,
    doctor_keywords=doctor_keywords,
    patient_keywords=patient_keywords,
)

audio_id: 685a60bf7772f10be9ac8c8b


INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/hyperparams.yaml'
INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
DEBUG:speechbrain.utils.parameter_transfer:Collecting files (or symlinks) for pretraining in /root/.cache/torch/pyannote/speechbrain.
INFO:speechbrain.utils.fetching:Fetch embedding_model.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["embedding_model"] = /root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt
INFO:speechbrain.utils.fetching:Fetch mean_var_norm_emb.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/mean_var_norm_emb.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["mean_var_norm_emb"] = /root/.cache/torch/pyannote/speechbrain/mean_var_norm_

[INFO] Generated summary length: 1174 characters, 220 tokens
Inserted conversation ID: 685a61477772f10be9ac8dd2


In [None]:
# Disabled: the last recorded run of this cell raised
# `ValueError: max() arg is an empty sequence` (see the saved output below).
# Re-enable once the cause of that error is resolved.
# name = 'encounter_joint_pain'
# final_pipeline(name, fs, model_whisper, pipeline_diarization, doctor_keywords, patient_keywords)

audio_id: 685a5efd7772f10be9ac8c35


INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/hyperparams.yaml'
INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
DEBUG:speechbrain.utils.parameter_transfer:Collecting files (or symlinks) for pretraining in /root/.cache/torch/pyannote/speechbrain.
INFO:speechbrain.utils.fetching:Fetch embedding_model.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["embedding_model"] = /root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt
INFO:speechbrain.utils.fetching:Fetch mean_var_norm_emb.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/mean_var_norm_emb.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["mean_var_norm_emb"] = /root/.cache/torch/pyannote/speechbrain/mean_var_norm_

ValueError: max() arg is an empty sequence

# Testing

### Retrieve conversation

In [None]:
# Look up one stored conversation by id and fetch the audio it links to.
# NOTE(review): the id is hard-coded; it is the one printed by the 'type_2_diabetes'
# run above and only exists in the database that run wrote to.
conversation_id = "685a61477772f10be9ac8dd2"
# str() is a no-op for the literal above; presumably kept so an ObjectId can be used too — TODO confirm.
record = get_conversation(db, str(conversation_id))
audio_binary = get_audio(fs, record["audio_id"])

In [None]:
# Play the retrieved audio inline; autoplay=True starts playback as soon as the output renders.
Audio(data=audio_binary, autoplay=True)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Show the summary stored on the retrieved record (a dict that includes an 'utterances' map).
record["summary"]

{'utterances': {'U1': "Doctor: Hi there, my name's Leah, I'm one of the junior doctors working in the GP surgery. Is it okay if I just check your name and date of birth please? Yeah,",
  'U2': "Patient: so it's Camilla, Camilla Weldon, and it's the 3rd of May 1977. Nice to meet you. Is it okay if I call you Camilla today? Yeah, of course. Fabulous.",
  'U3': 'Doctor: So how can I help you today Camilla?',
  'U4': 'Patient: Yeah, so the doctor I saw last week, he rang me yesterday to say that I had some blood tests and he just said that the blood tests said I had diabetes. So it was just to come and have a chat to you about, really about that. Okay,',
  'U5': 'Doctor: so how are you feeling about being told that news over the phone?',
  'U6': "Patient: To be honest, I was relieved because at least I sort of now know what's going on.",
  'U7': "Patient: But I have to say I'm a little bit anxious because I don't know a lot about diabetes but I'm not sure what's going on. I know that it is

### Testing (step-by-step)

In [None]:
# === Upload audio ===
# Step 1 of the manual walkthrough: store the .wav file and keep its id.
# NOTE(review): the saved outputs show a different audio_id for this same file here and in
# the final_pipeline run, so each execution appears to store another copy — TODO confirm.
name = 'fever_stomach'
audio_path = f"/content/drive/MyDrive/ClinicalNotesGen/Data/audios/en/wav/{name}.wav"
audio_id = upload_audio(fs, audio_path)
# Display the id of the stored file.
audio_id

ObjectId('685a57017772f10be9ac8ab8')

In [None]:
# Read the file that was just uploaded back from storage, using the id from the previous cell.
audio_binary = get_audio(fs, audio_id)

# Play it inline to check the stored audio; autoplay=True starts playback when the output renders.
Audio(data=audio_binary, autoplay=True)

Output hidden; open in https://colab.research.google.com to view.

- Dialogue2Text

In [None]:
# Transcribe and diarize the recording. Returns two lists: per-word timings (transcript_word)
# and per-utterance entries with a speaker label (transcript), as displayed in the next cells.
transcript_word, transcript = process_audio(audio_path, model_whisper, pipeline_diarization)

INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/hyperparams.yaml'
INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
DEBUG:speechbrain.utils.parameter_transfer:Collecting files (or symlinks) for pretraining in /root/.cache/torch/pyannote/speechbrain.
INFO:speechbrain.utils.fetching:Fetch embedding_model.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["embedding_model"] = /root/.cache/torch/pyannote/speechbrain/embedding_model.ckpt
INFO:speechbrain.utils.fetching:Fetch mean_var_norm_emb.ckpt: Using symlink found at '/root/.cache/torch/pyannote/speechbrain/mean_var_norm_emb.ckpt'
DEBUG:speechbrain.utils.parameter_transfer:Set local path in self.paths["mean_var_norm_emb"] = /root/.cache/torch/pyannote/speechbrain/mean_var_norm_

In [None]:
# Preview only the first entries: the list holds one {'word', 'start', 'end'} dict per
# spoken word, so displaying it in full floods the cell output.
transcript_word[:20]

[{'word': 'Hello,', 'start': np.float64(4.76), 'end': np.float64(5.44)},
 {'word': 'Mr.', 'start': np.float64(5.62), 'end': np.float64(5.86)},
 {'word': 'McKay.', 'start': np.float64(6.08), 'end': np.float64(6.42)},
 {'word': 'What', 'start': np.float64(7.28), 'end': np.float64(7.96)},
 {'word': 'brings', 'start': np.float64(7.96), 'end': np.float64(8.2)},
 {'word': 'you', 'start': np.float64(8.2), 'end': np.float64(8.48)},
 {'word': 'here', 'start': np.float64(8.48), 'end': np.float64(8.72)},
 {'word': 'today?', 'start': np.float64(8.72), 'end': np.float64(9.0)},
 {'word': 'I', 'start': np.float64(9.74), 'end': np.float64(10.42)},
 {'word': 'have', 'start': np.float64(10.42), 'end': np.float64(10.7)},
 {'word': 'a', 'start': np.float64(10.7), 'end': np.float64(10.84)},
 {'word': 'fever', 'start': np.float64(10.84), 'end': np.float64(11.3)},
 {'word': 'and', 'start': np.float64(11.3), 'end': np.float64(11.68)},
 {'word': 'a', 'start': np.float64(11.68), 'end': np.float64(11.9)},
 {'wor

In [None]:
# Utterance-level transcript: each entry has speaker, start, end, text and utterance_id.
transcript

[{'speaker': 'SPEAKER_00',
  'start': np.float64(4.76),
  'end': np.float64(6.42),
  'text': 'Hello, Mr. McKay.',
  'utterance_id': 'U1'},
 {'speaker': 'SPEAKER_00',
  'start': np.float64(7.28),
  'end': np.float64(9.0),
  'text': 'What brings you here today?',
  'utterance_id': 'U2'},
 {'speaker': 'SPEAKER_01',
  'start': np.float64(9.74),
  'end': np.float64(14.16),
  'text': 'I have a fever and a sore stomach. Okay, Tony.',
  'utterance_id': 'U3'},
 {'speaker': 'SPEAKER_00',
  'start': np.float64(15.28),
  'end': np.float64(21.3),
  'text': "I see your temperature is 104 degrees. That's very high.",
  'utterance_id': 'U4'},
 {'speaker': 'SPEAKER_01',
  'start': np.float64(22.24),
  'end': np.float64(27.2),
  'text': 'Yes, I feel very dizzy and nauseous. Did you get sick?',
  'utterance_id': 'U5'},
 {'speaker': 'SPEAKER_01',
  'start': np.float64(28.34),
  'end': np.float64(35.0),
  'text': 'Yes, I vomited twice this morning. Did you have any diarrhea? Yes,',
  'utterance_id': 'U6'},

- Role Classifier

In [None]:
# Assign Doctor / Patient roles to the diarization speaker labels using the two keyword lists.
transcript_labelled = classify_speakers(transcript, doctor_keywords, patient_keywords)

In [None]:
# Same utterance entries as `transcript`, with 'speaker' now holding 'Doctor' or 'Patient'.
transcript_labelled

[{'speaker': 'Doctor',
  'start': np.float64(4.76),
  'end': np.float64(6.42),
  'text': 'Hello, Mr. McKay.',
  'utterance_id': 'U1'},
 {'speaker': 'Doctor',
  'start': np.float64(7.28),
  'end': np.float64(9.0),
  'text': 'What brings you here today?',
  'utterance_id': 'U2'},
 {'speaker': 'Patient',
  'start': np.float64(9.74),
  'end': np.float64(14.16),
  'text': 'I have a fever and a sore stomach. Okay, Tony.',
  'utterance_id': 'U3'},
 {'speaker': 'Doctor',
  'start': np.float64(15.28),
  'end': np.float64(21.3),
  'text': "I see your temperature is 104 degrees. That's very high.",
  'utterance_id': 'U4'},
 {'speaker': 'Patient',
  'start': np.float64(22.24),
  'end': np.float64(27.2),
  'text': 'Yes, I feel very dizzy and nauseous. Did you get sick?',
  'utterance_id': 'U5'},
 {'speaker': 'Patient',
  'start': np.float64(28.34),
  'end': np.float64(35.0),
  'text': 'Yes, I vomited twice this morning. Did you have any diarrhea? Yes,',
  'utterance_id': 'U6'},
 {'speaker': 'Doctor

- SOAP Gen

In [None]:
# Turn the labelled utterances into the text form the summarizer takes
# (presumably "Role: text" lines, as seen in the summary's 'utterances' — TODO confirm).
dialogue = transcript2string(transcript_labelled)
# Generate the summary with the fine-tuned model; tok, tok_fast and ft_model must already be loaded.
summary = process_soap_traceability(dialogue, tok, tok_fast, ft_model)

The following generation flags are not valid and may be ignored: ['length_penalty']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


[INFO] Generated summary length: 792 characters, 164 tokens


In [None]:
# Inspect the result: a dict with an 'utterances' map plus section lists (e.g. 'S') whose
# items carry sentence_idx, sentence_text and the utterance_id they trace back to.
summary

{'utterances': {'U1': 'Doctor: Hello, Mr. McKay.',
  'U2': 'Doctor: What brings you here today?',
  'U3': 'Patient: I have a fever and a sore stomach. Okay, Tony.',
  'U4': "Doctor: I see your temperature is 104 degrees. That's very high.",
  'U5': 'Patient: Yes, I feel very dizzy and nauseous. Did you get sick?',
  'U6': 'Patient: Yes, I vomited twice this morning. Did you have any diarrhea? Yes,',
  'U7': 'Doctor: a little bit. Did you take any medicine to treat your symptoms? No, doctor.',
  'U8': "Patient: I didn't take anything. Okay,",
  'U9': 'Doctor: sounds like you may have some food poisoning. Oh, no.',
  'U10': "Doctor: Take this medicine now and again every six hours until it's finished. You'll be okay. You'll be okay in about 24 hours. That's",
  'U11': 'Patient: a relief. Thank you very much, doctor. Thank you, doctor.'},
 'S': [{'sentence_idx': '0',
   'sentence_text': 'Patient, Tony McKay, presents with a fever and sore stomach.',
   'utterance_id': 'U3'},
  {'sentence_

In [None]:
# Store the walkthrough result. `name` is passed as the second argument so this call uses
# the same argument order as the insert_conversation call inside final_pipeline.
conversation_id = insert_conversation(db, name, audio_id, transcript, transcript_word, summary)
print("Inserted conversation ID:", conversation_id)

Inserted conversation ID: 685a5bb17772f10be9ac8ac2
