In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Import

In [2]:
import json
import time
import os
import pandas as pd
import numpy as np
import re
import string
import unicodedata
import collections
from copy import deepcopy
import torch

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    GenerationConfig,
    TrainingArguments,
    Trainer,
    pipeline,
    PegasusForConditionalGeneration,
    PegasusTokenizer,
    PegasusTokenizerFast,
)

from peft import PeftModel, PeftConfig, get_peft_model_state_dict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cpu')

In [11]:
device

device(type='cpu')

### Utils

In [3]:
def transcript2string(utterances: list[dict]) -> str:
    dialogue_lines = [
        f"{entry['speaker']}: {entry['text'].strip()}"
        for entry in utterances
        if 'speaker' in entry and 'text' in entry
    ]
    return "\n".join(dialogue_lines)

In [4]:
def load_from_json(filename):
    """
    Load the conversation from a JSON file.
    """
    with open(filename, 'r') as f:
        return json.load(f)

## Model, Tokenizer

In [5]:
# fine_tune_path = '/content/drive/MyDrive/ClinicalNotesGen/Summarization/3_Fine_Tune_LLM'
# model_name = 'pegasus' # ADJUST
# sub_model_name = 'pegasus_xsum' # ADJUST
# checkpoints_dir = f"{fine_tune_path}/{model_name}/{sub_model_name}/lora_1" # ADJUST
# checkpoints_path = f"{checkpoints_dir}/checkpoints"
final_checkpoints_path = f"./final_checkpoints"
summary_path = f"./summary"

print(f"final_checkpoints_path: {final_checkpoints_path}")

final_checkpoints_path: ./final_checkpoints


In [12]:
# FT model
ft_base = AutoModelForSeq2SeqLM.from_pretrained(perf_config.base_model_name_or_path, return_dict=True, device_map=device)
ft_model = PeftModel.from_pretrained(ft_base, final_checkpoints_path)

Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-xsum and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# PeftConfig
perf_config = PeftConfig.from_pretrained(final_checkpoints_path)

# Tokenizer
tok = AutoTokenizer.from_pretrained(final_checkpoints_path)
tok_fast   = PegasusTokenizerFast.from_pretrained(final_checkpoints_path)

# FT model
ft_base = AutoModelForSeq2SeqLM.from_pretrained(perf_config.base_model_name_or_path, return_dict=True, device_map=device)
ft_model = PeftModel.from_pretrained(ft_base, final_checkpoints_path)

Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-xsum and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
MAX_SOURCE_LEN = 512
MAX_TARGET_LEN = 428

gen_cfg = deepcopy(ft_model.generation_config)
gen_cfg.max_new_tokens = MAX_TARGET_LEN
gen_cfg.num_beams      = 1
gen_cfg.do_sample      = False
gen_cfg.early_stopping = False

gen_cfg

GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "forced_eos_token_id": 1,
  "length_penalty": 0.6,
  "max_length": 64,
  "max_new_tokens": 428,
  "pad_token_id": 0
}

# Generate

In [8]:
def get_response(dialogue, ft_model):
    prompt = f"""{dialogue}"""

    # 1. Tokenise input
    inputs = tok(
        prompt,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        max_length=MAX_SOURCE_LEN
    ).to(device)

    # 2. Inference: Fine-tuned model
    with torch.no_grad():
        gen_ids = ft_model.generate(**inputs, generation_config=gen_cfg)
    output = tok.decode(gen_ids[0], skip_special_tokens=True).strip()

    # 3. Debug info
    print(f"Input token length: {inputs['input_ids'].shape[1]}")

    return output

In [13]:
def process_soap_traceability(dialogue_name, dialogue, summary_path):
    # split into utterances + IDs --------------------------------------
    parts = re.split(r'(Doctor:|Patient:)', dialogue)[1:]
    utterances = [parts[i] + parts[i + 1] for i in range(0, len(parts), 2)]
    utterance_ids = [f"U{i+1}" for i in range(len(utterances))]
    # dict: utt - utt_id
    utt_dict = {uid: utt.strip() for uid, utt in zip(utterance_ids, utterances)}

    # --- map *encoder* tokens → utterance -----------------------------
    enc_inputs   = tok(dialogue, return_tensors="pt", max_length=512, truncation=True).to(device)
    enc_tokens   = tok.convert_ids_to_tokens(enc_inputs["input_ids"][0])

    token_to_utt = []
    cur_utt_idx  = 0
    utt_starts   = []
    for i, utt in enumerate(utterances):
        first_piece = tok.tokenize(utt.strip())[0]
        try:
            pos = enc_tokens.index(first_piece,
                                  utt_starts[-1] + 1 if i else 0)
        except ValueError:
            pos = utt_starts[-1] if utt_starts else 0
        utt_starts.append(pos)

    for i in range(len(enc_tokens)):
        j = max(j for j, start in enumerate(utt_starts) if i >= start)
        token_to_utt.append(utterance_ids[j])

    # ------------------------------------------------------------------
    # 3) GENERATE SOAP summary -----------------------------------------
    gen_ids = ft_model.generate(**enc_inputs, generation_config=gen_cfg)
    summary_text   = tok_fast.decode(gen_ids[0],
                                    skip_special_tokens=True,
                                    clean_up_tokenization_spaces=True)

    print(f"[INFO] Generated summary length: {len(summary_text)} characters, {len(gen_ids[0])} tokens")

    # Forced decoding to get attentions
    with torch.no_grad():
        outs = ft_model(**enc_inputs, labels=gen_ids, output_attentions=True)

    # (layers, b, h, dec_len, enc_len)
    all_xattn = torch.stack(outs.cross_attentions, dim=0)
    avg_xattn = all_xattn.mean(dim=(0, 2))[0]                    # (dec_len, enc_len)

    # ---------------------------------------------------------------
    # Decoder tokens without specials
    enc_summary = tok_fast(summary_text,
                          add_special_tokens=False,
                          return_offsets_mapping=True)

    dec_tokens  = tok_fast.convert_ids_to_tokens(enc_summary["input_ids"])
    offsets     = enc_summary["offset_mapping"]           # (dec_len, 2)

    # ---------------------------------------------------------------
    # Align rows ↔ dec_tokens 1-to-1
    bos_id, eos_id = tok_fast.bos_token_id, tok_fast.eos_token_id
    keep_mask_full = (gen_ids[0] != bos_id) & (gen_ids[0] != eos_id)
    keep_mask = keep_mask_full[:-1]  # Trim 1 element: drop the final position because EOS has **no** cross-attn row

    # avg_xattn = avg_xattn[keep_mask]
    # avg_xattn = avg_xattn[:len(dec_tokens)]
    avg_xattn = avg_xattn[:len(dec_tokens), :]
    # avg_xattn = avg_xattn[1:]

    if avg_xattn.shape[0] != len(dec_tokens):
        print("Warning: shape mismatch after filtering special tokens:",
              avg_xattn.shape[0], "rows vs", len(dec_tokens), "decoder tokens")

    tok_char_starts = [off[0] for off in offsets]


    # ------------------------------------------------------------------
    # 4) PARSE SOAP into sections + sentences --------------------------
    soap_sec_pat = re.compile(r'\b([SOAP]):')
    matches      = list(soap_sec_pat.finditer(summary_text))
    sent_entries = []                       # [{section,S/O/A/P; idx; start; end; text}]

    for idx, m in enumerate(matches):
        sec = m.group(1)
        start_txt = m.end()
        end_txt   = matches[idx+1].start() if idx+1 < len(matches) else len(summary_text)
        section_text = summary_text[start_txt:end_txt].strip()

        # split into sentences
        sent_pats = re.split(r'(?<=[.!?])\s+', section_text)
        cursor = start_txt
        local_idx = 0
        for s in sent_pats:
            s_clean = s.strip()
            if not s_clean:
                continue
            s_start = summary_text.find(s_clean, cursor, end_txt)
            s_end   = s_start + len(s_clean)
            sent_entries.append({
                "section":   sec,
                "sent_idx":  local_idx,
                "start":     s_start,
                "end":       s_end,
                "text":      s_clean,
            })
            cursor     = s_end + 1
            local_idx += 1

    # map decoder-token pos → sentence index ---------------------------
    tok2sent = []
    cur_sent = 0
    for pos in tok_char_starts:
        while (cur_sent + 1 < len(sent_entries) and
              pos >= sent_entries[cur_sent + 1]["start"]):
            cur_sent += 1
        tok2sent.append(cur_sent)

    # ------------------------------------------------------------------
    # 5) CROSS-ATTENTION -----------------------------------------------

    SPECIALS = set(tok_fast.all_special_tokens)
    SECTION_TOKENS = {"▁S", "▁O", "▁A", "▁P", "S", "O", "A", "P"}
    STOP_WORDS = {
        "the", "a", "an", "and", "or", "but", "if", "with", "of", "for", "to",
        "in", "on", "at", "by", "is", "are", "was", "were", "be", "been", "am",
        "it", "this", "that", "these", "those", "as", "has", "have", "had",
    }

    def all_punct(text: str) -> bool:
        """True if every character in text is Unicode punctuation."""
        return all(unicodedata.category(ch).startswith("P") for ch in text)

    def is_skip(tok):
        t = tok.replace('▁', '').lower()

        return (
            t in string.punctuation        # ← punctuation
            or t in STOP_WORDS             # ← stop-word list you defined
            or tok in SPECIALS             # ← <pad>, </s>, <s> …
            or tok in SECTION_TOKENS
            or t == ""                     # ← empty after stripping ▁
            or t in {"", "<pad>", "</s>"}
            or all_punct(t)
        )

    top_k_token_att = 5 # keep k strongest attentions
    # ------------------------------------------------------------------
    # 6) TOKEN-LEVEL vote (top-k enc tokens) ---------------------------
    token_vote_utt_dict = {}      # dec_pos -> winner utterance (e.g. 3: "U1")
    token_top_utts      = {}      # dec_pos -> [k] utterances  (e.g. 3: ['U1', 'U1', 'U1']) # debug only

    for dpos, scores in enumerate(avg_xattn):            # dpos indexes dec_tokens
        if is_skip(dec_tokens[dpos]):                    # skip punc / stop-words / specials
            continue                                     #  ↳ not stored anywhere


        # Sort encoder tokens by attention scores, take top-k non-skipped ones.
        keep = []
        for idx in torch.argsort(scores, descending=True):
            if not is_skip(enc_tokens[idx]):
                keep.append(int(idx))
                if len(keep) == top_k_token_att:         # got k, stop
                    break
        if not keep:                                     # Rare case: all top-k are skipped
            continue

        # Map encoder token indices to utterance IDs using token_to_utt.
        utts   = [token_to_utt[i] for i in keep]
        win    = collections.Counter(utts).most_common(1)[0][0]

        token_vote_utt_dict[dpos] = win
        token_top_utts[dpos]      = utts                # debug only

    # ------------------------------------------------------------------
    # 7) SENTENCE-LEVEL vote -------------------------------------------
    # Map decoder token → sentence index
    tok2sent = []
    cur_sent = 0
    for start_char, _ in offsets:                        # offsets length == dec_tokens
        while cur_sent + 1 < len(sent_entries) and start_char >= sent_entries[cur_sent + 1]["start"]:
            cur_sent += 1
        tok2sent.append(cur_sent)

    # Majority vote for each sentence
    for global_idx, s in enumerate(sent_entries):
        winners = [token_vote_utt_dict[d]
            for d, sid in enumerate(tok2sent)
                if sid == global_idx and d in token_vote_utt_dict]
        s["utterance_id"] = (
            collections.Counter(winners).most_common(1)[0][0] if winners else "Unknown"
        )

    # ------------------------------------------------------------------
    # 8) BUILD FINAL JSON ----------------------------------------------
    final = {"utterances": utt_dict,        # 🔹NEW section
            "S": [], "O": [], "A": [], "P": []}

    for sent in sent_entries:
        final[sent["section"]].append({
            "sentence_idx":  str(sent["sent_idx"]),
            "sentence_text": sent["text"],
            "utterance_id":  sent["utterance_id"],
        })

    json_path = f"{summary_path}/summary_{dialogue_name}.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(final, f, ensure_ascii=False, indent=2)

    print(f"Saved {json_path}")

In [15]:
dialogue_name = "fever_stomach"
input_path = f"../data/labelled_{dialogue_name}.json"
transcript = load_from_json(f"{input_path}")
dialogue = transcript2string(transcript)
summary_path = "../data/summary"
process_soap_traceability(dialogue_name, dialogue, summary_path)

The following generation flags are not valid and may be ignored: ['length_penalty']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[INFO] Generated summary length: 792 characters, 164 tokens
Saved ../data/summary/summary_fever_stomach.json
