## Utility functions

In [None]:
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Mapping, Sequence, Union
import csv
import difflib
import itertools
import json
import math
import os
import re
import string
import textwrap
import unicodedata
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import nltk
import numpy as np
import plotly.graph_objects as go
from scipy.signal import savgol_filter
from sentence_transformers import SentenceTransformer
from span_marker import SpanMarkerModel
import spacy
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from gliner import GLiNER
import math
from typing import Dict, List, Optional, Tuple, Union

import networkx as nx
import numpy as np
import pandas as pd
import torch
from scipy.stats import entropy
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# Large Italian spaCy pipeline, loaded once when this cell runs.
# NOTE(review): `nlp` is not referenced in the cells shown here (LexicalTopicStrength
# loads its own "it_core_news_sm") — confirm it is still needed.
nlp = spacy.load("it_core_news_lg")


In [None]:
def save_paragraphs(paragraphs: Sequence[str], path: Optional[Union[str, Path]] = None) -> None:
    """Write paragraphs to a two-column CSV file (``ID``, ``paragraph``).

    Args:
        paragraphs: paragraph texts; the row ID is the position in this sequence.
        path: destination file (``str`` or ``Path``). Defaults to the global
            ``CSV_PATH``, looked up at call time because that constant is
            defined in a later notebook cell than this function.
    """
    path = CSV_PATH if path is None else Path(path)

    # newline="" lets the csv module control line endings (required by csv docs).
    with path.open("w", encoding="utf8", newline="") as fh:
        writer = csv.writer(fh, quoting=csv.QUOTE_MINIMAL)
        writer.writerow(("ID", "paragraph"))
        writer.writerows(enumerate(paragraphs))


def load_embedding_model(model_name, device="cuda"):
    """Load a SentenceTransformer model.

    Args:
        model_name: model identifier or local path passed to SentenceTransformer.
        device: target device; "cuda" falls back to "cpu" when CUDA is not
            available (same policy as FluidDynamicTopicFlow).

    Returns:
        The loaded SentenceTransformer instance.
    """
    if device == "cuda" and not torch.cuda.is_available():
        device = "cpu"
    # NOTE(review): trust_remote_code=True executes code shipped with the model
    # repository; only use it with trusted model names.
    return SentenceTransformer(model_name, device=device, trust_remote_code=True)


class SentenceDataset(Dataset):
    """Map-style ``Dataset`` exposing a sequence of sentences item by item.

    Items are returned exactly as stored (no tokenisation or copying), so a
    ``DataLoader`` simply batches the raw sentences.
    """

    def __init__(self, sentences):
        # Any sized, indexable container works (presumably a list of str here).
        self.sentences = sentences

    def __len__(self):
        count = len(self.sentences)
        return count

    def __getitem__(self, idx):
        item = self.sentences[idx]
        return item

def save_paragraph_embeddings(embeds: torch.Tensor, path: str = "paragraphs_embeds.pt") -> None:
    """Serialise ``embeds`` with ``torch.save`` and print the destination path."""
    target = path
    torch.save(embeds, target)
    print(f"Embeddings saved to {target}")


def sankey_to_json(labels, src, tgt, val, colors=None, *,
                   indent=2, file=None):
    """Serialise a Sankey diagram to a JSON string.

    Args:
        labels: node labels; node ``i`` becomes ``{"id": i, "label": labels[i]}``.
        src, tgt, val: parallel sequences describing links. ``src[n]`` and
            ``tgt[n]`` are node indices and ``val[n]`` is the link value.
            Iteration stops at the shortest of the three (``zip`` semantics).
        colors: optional per-node colours, indexed like ``labels``; ignored
            when falsy.
        indent: forwarded to ``json.dumps``.
        file: if given, the JSON is written to this path (UTF-8); otherwise
            it is printed.

    Returns:
        The JSON text. Links refer to nodes by *label*, not by id, so labels
        are presumably expected to be unique.
    """
    nodes = []
    for node_id, label in enumerate(labels):
        node = {"id": node_id, "label": label}
        if colors:
            node["color"] = colors[node_id]
        nodes.append(node)

    links = []
    for s, t, v in zip(src, tgt, val):
        links.append({"source": labels[s], "target": labels[t], "value": float(v)})

    payload = json.dumps({"nodes": nodes, "links": links}, indent=indent, ensure_ascii=False)

    if not file:
        print(payload)
    else:
        with open(file, "w", encoding="utf-8") as fh:
            fh.write(payload)
    return payload


def load_stopwords(file_path: str = 'stopwords.txt', *, fallback=None) -> set:
    """Load a stop-word set from a UTF-8 text file (one word per line).

    Blank lines are skipped; words are stripped and lower-cased.

    Args:
        file_path: path of the stop-word list.
        fallback: iterable of words used when ``file_path`` does not exist.
            Defaults to spaCy's built-in Italian stop words (the notebook works
            with the Italian ``it_core_news_*`` pipelines).

    Returns:
        A set of lower-cased stop words. Never ``None``, so callers such as
        ``set(load_stopwords())`` can use the result directly.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return {line.strip().lower() for line in f if line.strip()}
    except FileNotFoundError:
        pass

    if fallback is None:
        # Imported lazily: spaCy's language data is only needed on this path.
        from spacy.lang.it.stop_words import STOP_WORDS
        print(f"Warning: Stopwords file {file_path} not found. Using spaCy's default stopwords.")
        fallback = STOP_WORDS
    else:
        print(f"Warning: Stopwords file {file_path} not found. Using the provided fallback stopwords.")
    return {word.lower() for word in fallback}


In [None]:

# Output CSV written by save_paragraphs() (also its default destination).
CSV_PATH = Path("paragraphs.csv")
model_name = "BAAI/bge-m3" # model for sentence embedding
# NOTE(review): "authrorial" looks like a typo for "authorial"; left unchanged
# because the name may need to match files produced by earlier runs.
sankey_name ='sankey_authrorial_ALL'
# Batch size for the DataLoader used when encoding paragraphs.
batch_size = 16

## Load and pre-process input documents

In [None]:
def load_and_preprocess_docx(file_path):
    """Read a .docx file and return its paragraphs as one normalised string.

    Per paragraph, in this order:
      * drop the two-character sequences "- " and " -" (e.g. "paro- la" ->
        "parola"); note this also removes spaced dashes;
      * re-insert the missing space in "JacopoBelbo" / "AbrahamAbulafia";
      * rewrite "Rosa-Croce" as "Rosacroce".
    Paragraphs are joined with "\\n" and leading/trailing newlines are stripped.
    """
    # NOTE(review): `Document` is not imported in the cells shown; presumably
    # `from docx import Document` (python-docx) — confirm it is in scope.
    source = Document(file_path)

    # Substrings removed outright, applied sequentially in this order.
    hyphen_junk = ("- ", " -", " - ")
    # Glued character names -> spaced form; whitespace captured on either
    # side is written back unchanged.
    name_fixes = (
        (re.compile(r'(\s*)JacopoBelbo(\s*)'), r'\1Jacopo Belbo\2'),
        (re.compile(r'(\s*)AbrahamAbulafia(\s*)'), r'\1Abraham Abulafia\2'),
    )

    cleaned_lines = []
    for paragraph in source.paragraphs:
        line = paragraph.text
        for junk in hyphen_junk:
            line = line.replace(junk, "")
        for pattern, replacement in name_fixes:
            line = pattern.sub(replacement, line)
        line = line.replace("Rosa-Croce", "Rosacroce")
        cleaned_lines.append(line)

    return "\n".join(cleaned_lines).strip('\n')




def segment_text_into_paragraphs(text: str) -> list:
    """Split raw document text into a list of cleaned, non-empty paragraphs.

    Line breaks *inside* a paragraph are undone first. Paragraph boundaries
    are then detected as blank lines (two or more consecutive newlines), runs
    of four or more whitespace characters, or separator runs such as
    "---", "***" or "###".

    Args:
        text: raw text, e.g. the output of ``load_and_preprocess_docx``.

    Returns:
        Paragraph strings with surrounding whitespace stripped; empty pieces
        are dropped.
    """
    # Whitespace other than "\n". The joins below may span ONE line break but
    # must not swallow a blank line (a paragraph boundary); plain ``\s*`` would
    # also match newlines and merge paragraphs ending in a word character.
    ws = r'[^\S\n]*'
    # Re-join words hyphenated across a line break: "paro-\nla" -> "parola".
    text = re.sub(rf'-{ws}\n{ws}', '', text)
    # Join a single word-to-word line break with one space.
    text = re.sub(rf'(\w+){ws}\n{ws}(\w+)', r'\1 \2', text)
    # Any remaining isolated newline becomes a space; "\n\n" runs survive.
    text = re.sub(r'(?<!\n)\n(?!\n)', ' ', text)

    paragraphs = re.split(r'(?:\n{2,}|\s{4,}|[-_*#]{3,})', text)
    return [para.strip() for para in paragraphs if para.strip()]




In [None]:
root='./'
import torch

# Each chapter must be placed in its own folder, containing one file for each section of the chapter.
# For example: ./Chapter1_KETEL/C1_1.docx for Chapter 1 (KETEL), section 1.
# Sections are read as C{n}_1.docx, C{n}_2.docx, ... up to the first missing index.

# NOTE(review): '_BINA H' contains a space; confirm it matches the folder name on disk.
chapter_names = {1:'_KETEL', 2:'_HOKMAH', 3:'_BINA H', 4:'_HESED'}
chapter_paragraphs = {}    # chapter number -> all paragraphs of the chapter, in order
chapter_length = {}        # chapter number -> total paragraph count
chapter_inter_length = {}  # chapter number -> paragraph count of each non-empty section
for n_c in chapter_names:
    chapter = f'Chapter{n_c}'
    title = chapter_names[n_c]
    lengths = []
    docs = []
    count = 0
    i = 1
    while True:
        section_path = f'{root}{chapter}{title}/C{n_c}_{i}.docx'
        if not os.path.exists(section_path):
            break
        # Advance to the next section BEFORE any `continue`; otherwise an
        # empty section would be re-read forever.
        i += 1
        print(section_path)
        doc1 = load_and_preprocess_docx(section_path)
        doc2 = segment_text_into_paragraphs(doc1)
        if len(doc2) <= 0:
            continue
        doc1 = [dd.replace('\n', '').strip() for dd in doc2 if len(dd) > 0]
        print(len(doc1))
        count += len(doc1)
        lengths.append(len(doc1))
        docs.extend(doc1)
    chapter_length[n_c] = count
    chapter_inter_length[n_c] = lengths
    doc = " ".join(docs)
    paragraphs = docs
    chapter_paragraphs[n_c] = paragraphs
    torch.cuda.empty_cache()
n_c = 1


## Load topics

In [None]:
# Topic definitions consumed by the strength models below.
# NOTE(review): presumably a JSON object {topic_name: [phrase, ...]} — the
# classes also accept a plain list; confirm against topics_o.json.
topics = json.loads(Path("topics_o.json").read_text(encoding="utf-8"))


## Model definition

In [None]:
  # Stop-word set consulted by _good_token() in the NER helper cells.
  # `or ()` keeps this working if load_stopwords() returns None (missing file).
  _STOP = set(load_stopwords() or ())

In [None]:
use_span_maker = False  # if True use a SpanMarker model for NER, GLiNER otherwise

In [None]:
if use_span_maker:
  # SpanMarker NER backend (MultiNERD label set), used when use_span_maker is True.

  _MODEL = "tomaarsen/span-marker-mbert-base-multinerd"
  _mod = SpanMarkerModel.from_pretrained(_MODEL)
  if torch.cuda.is_available():
      _mod = _mod.to("cuda")

  def _ner(text: str):
      """Return SpanMarker predictions for ``text``.

      param_extractor() reads the "span" and "label" keys of each prediction.
      """
      return _mod.predict(text)

  # NER label -> narrative category. Labels missing here are ignored by
  # param_extractor().
  _MAP = {
      "PER":  "character",
      "ORG":  "organization",
      "LOC":  "space",
      "TIME": "time",
      "ANIM": "object", "BIO": "object", "CEL": "object", "DIS": "object",
      "EVE": "object", "FOOD": "object", "INST": "object", "MEDIA": "object",
      "PLANT": "object", "MYTH": "object", "VEHI": "object",
  }
  # Alias for the name this table had before; param_extractor() reads _MAP.
  MAP = _MAP

  # Time expressions (Italian):
  # "12 marzo"
  _RE_DAYMONTH = re.compile(
      r"\b\d{1,2}\s(?:gennaio|febbraio|marzo|aprile|maggio|giugno|luglio|"
      r"agosto|settembre|ottobre|novembre|dicembre)\b",
      re.I,
  )
  # "1943-45" (second year abbreviated; expanded in param_extractor)
  _RE_YEAR_RANGE_SHORT = re.compile(r"\b(?P<y1>1[3-9]\d{2}|20\d{2})\s*[-–—]\s*(?P<y2>\d{2})\b")
  # "1943-1945"
  _RE_YEAR_RANGE_FULL  = re.compile(r"\b(1[3-9]\d{2}|20\d{2})\s*[-–]\s*(1[3-9]\d{2}|20\d{2})\b")
  # "secoli XV-XVI"
  _RE_CENT_RANGE = re.compile(r"\b(?:secoli?|sec\.?)\s+([IVXLCDM]+)\s*[-–—]\s*([IVXLCDM]+)\b", re.I)
  # Years 1300-2099
  _RE_YEAR = re.compile(r"\b(1[3-9]\d{2}|20\d{2})\b")
  # "secolo XV" / "sec. 15" (findall returns only the numeral)
  _RE_CENT = re.compile(r"\b(?:secolo|sec\.?)\s+([IVXLCDM]+|\d{1,2})\b", re.I)

  def _cleanx(tok: str) -> str:
      """Remove '##' subword markers and spaces, then strip punctuation/dashes."""
      tok = tok.replace("##", "").replace(" ", "")
      return tok.strip(string.punctuation + "-–—")

  def _good_token(tok: str, min_len: int = 3) -> bool:
      """Keep tokens of >= min_len chars, not stop words, digit-free, with >= 2 ASCII vowels."""
      if len(tok) < min_len:
          return False
      if tok.lower() in _STOP:
          return False
      if any(ch.isdigit() for ch in tok):
          return False
      if sum(ch in "aeiouAEIOU" for ch in tok) < 2:
          return False
      return True

  def _dedup_filter(lst: List[str], df: int = 1) -> List[str]:
      """Unique tokens occurring at least ``df`` times that pass _good_token()."""
      cnt = Counter(lst)
      return [t for t in cnt if cnt[t] >= df and _good_token(t)]

  def param_extractor(text: str, *, df_threshold: int = 1) -> Dict[str, List[str]]:
      """Extract entity tokens from ``text`` grouped by narrative category.

      Returns:
          {category: sorted unique tokens}; categories come from _MAP plus
          "time" (regex matches). Empty categories are omitted.
      """
      buck = defaultdict(list)

      for ent in _ner(text):
          cat = _MAP.get(ent["label"])
          if not cat:
              continue
          tok = _cleanx(ent["span"])
          if _good_token(tok):
              buck[cat].append(tok)

      buck["time"].extend(_RE_DAYMONTH.findall(text))
      buck["time"].extend(_RE_YEAR.findall(text))
      buck["time"].extend(_RE_CENT.findall(text))

      for m in _RE_YEAR_RANGE_SHORT.finditer(text):
          y1, y2 = int(m.group("y1")), int(m.group("y2"))
          y2_full = y1 // 100 * 100 + y2  # "1943-45" -> 1945
          buck["time"].append(f"{y1}-{y2_full}")

      for m in _RE_YEAR_RANGE_FULL.finditer(text):
          buck["time"].append(f"{m.group(1)}-{m.group(2)}")

      for m in _RE_CENT_RANGE.finditer(text):
          buck["time"].append(f"{m.group(1)}-{m.group(2)}")

      # NOTE(review): _dedup_filter() applies _good_token(), which rejects any
      # token containing a digit and most century numerals ("XV", "XVI"), so
      # the regex time matches above are currently always dropped — confirm
      # whether that is intended.
      out = {}
      for k, toks in buck.items():
          toks = _dedup_filter(toks, df_threshold)
          if toks:
              out[k] = sorted(set(toks))

      return out




In [None]:


if not use_span_maker:
  # GLiNER NER backend, used when use_span_maker is False.

  _MODEL = "urchade/gliner_multi-v2.1"
  _mod = GLiNER.from_pretrained(_MODEL)
  if torch.cuda.is_available():
      try:
          _mod.model = _mod.model.to("cuda")
      except Exception as exc:  # best effort: keep running on CPU, but say so
          print(f"Warning: could not move GLiNER to CUDA ({exc}); running on CPU.")

  # Entity labels requested from GLiNER (MultiNERD-style codes, same set as the
  # SpanMarker backend).
  # NOTE(review): GLiNER matches free-text label strings; short codes such as
  # "CEL" or "BIO" may be matched less reliably than descriptive labels.
  _LABELS = ["PER","ORG","LOC","TIME","ANIM","BIO","CEL","DIS","EVE","FOOD","INST","MEDIA","PLANT","MYTH","VEHI"]

  # NER label -> narrative category used by param_extractor(). Labels missing
  # here are ignored.
  _MAP = {
      "PER":  "character",
      "ORG":  "organization",
      "LOC":  "space",
      "TIME": "time",
      "ANIM": "object", "BIO": "object", "CEL": "object", "DIS": "object",
      "EVE": "object", "FOOD": "object", "INST": "object", "MEDIA": "object",
      "PLANT": "object", "MYTH": "object", "VEHI": "object",
  }

  def _ner(text: Union[str, List[str]]):
      """Run GLiNER on a string (or each string of a list).

      Returns a flat list of {"span": text, "label": label} dicts, skipping
      predictions with an empty text or label.
      """
      def _convert(ents):
          out = []
          for e in ents:
              span = e.get("text")
              label = e.get("label")
              if span and label:
                  out.append({"span": span, "label": label})
          return out

      if isinstance(text, str):
          return _convert(_mod.predict_entities(text, _LABELS))
      if isinstance(text, list):
          results = []
          for t in text:
              results.extend(_convert(_mod.predict_entities(t, _LABELS)))
          return results
      raise TypeError(f"_ner expects str or list of str, got {type(text).__name__}")

  # Time expressions (Italian):
  # "12 marzo"
  _RE_DAYMONTH = re.compile(
      r"\b\d{1,2}\s(?:gennaio|febbraio|marzo|aprile|maggio|giugno|luglio|"
      r"agosto|settembre|ottobre|novembre|dicembre)\b",
      re.I,
  )
  # "1943-45" (second year abbreviated; expanded in param_extractor)
  _RE_YEAR_RANGE_SHORT = re.compile(r"\b(?P<y1>1[3-9]\d{2}|20\d{2})\s*[-–—]\s*(?P<y2>\d{2})\b")
  # "1943-1945"
  _RE_YEAR_RANGE_FULL  = re.compile(r"\b(1[3-9]\d{2}|20\d{2})\s*[-–]\s*(1[3-9]\d{2}|20\d{2})\b")
  # "secoli XV-XVI"
  _RE_CENT_RANGE = re.compile(r"\b(?:secoli?|sec\.?)\s+([IVXLCDM]+)\s*[-–—]\s*([IVXLCDM]+)\b", re.I)
  # Years 1300-2099
  _RE_YEAR = re.compile(r"\b(1[3-9]\d{2}|20\d{2})\b")
  # "secolo XV" / "sec. 15" (findall returns only the numeral)
  _RE_CENT = re.compile(r"\b(?:secolo|sec\.?)\s+([IVXLCDM]+|\d{1,2})\b", re.I)

  def _cleanx(tok: str) -> str:
      """Remove '##' subword markers and spaces, then strip punctuation/dashes."""
      tok = tok.replace("##", "").replace(" ", "")
      return tok.strip(string.punctuation + "-–—")

  def _good_token(tok: str, min_len: int = 3) -> bool:
      """Keep tokens of >= min_len chars, not stop words, digit-free, with >= 2 ASCII vowels."""
      if len(tok) < min_len:
          return False
      if tok.lower() in _STOP:
          return False
      if any(ch.isdigit() for ch in tok):
          return False
      if sum(ch in "aeiouAEIOU" for ch in tok) < 2:
          return False
      return True

  def _dedup_filter(lst: List[str], df: int = 1) -> List[str]:
      """Unique tokens occurring at least ``df`` times that pass _good_token()."""
      cnt = Counter(lst)
      return [t for t in cnt if cnt[t] >= df and _good_token(t)]

  def param_extractor(text: str, *, df_threshold: int = 1) -> Dict[str, List[str]]:
      """Extract entity tokens from ``text`` grouped by narrative category.

      Returns:
          {category: sorted unique tokens}; categories come from _MAP plus
          "time" (regex matches). Empty categories are omitted.
      """
      buck = defaultdict(list)
      for ent in _ner(text):
          cat = _MAP.get(ent["label"])
          if not cat:
              continue
          tok = _cleanx(ent["span"])
          if _good_token(tok):
              buck[cat].append(tok)
      buck["time"].extend(_RE_DAYMONTH.findall(text))
      buck["time"].extend(_RE_YEAR.findall(text))
      buck["time"].extend(_RE_CENT.findall(text))
      for m in _RE_YEAR_RANGE_SHORT.finditer(text):
          y1, y2 = int(m.group("y1")), int(m.group("y2"))
          y2_full = y1 // 100 * 100 + y2  # "1943-45" -> 1945
          buck["time"].append(f"{y1}-{y2_full}")
      for m in _RE_YEAR_RANGE_FULL.finditer(text):
          buck["time"].append(f"{m.group(1)}-{m.group(2)}")
      for m in _RE_CENT_RANGE.finditer(text):
          buck["time"].append(f"{m.group(1)}-{m.group(2)}")
      # NOTE(review): _dedup_filter() applies _good_token(), which rejects any
      # token containing a digit and most century numerals ("XV", "XVI"), so
      # the regex time matches above are currently always dropped — confirm
      # whether that is intended.
      out = {}
      for k, toks in buck.items():
          toks = _dedup_filter(toks, df_threshold)
          if toks:
              out[k] = sorted(set(toks))
      return out


### Lexical Topic Strength

In [None]:

class LexicalTopicStrength:
    """Score paragraphs against topics with a TF-IDF-style lexical overlap.

    Each topic is concatenated into one pseudo-document. ``topics`` is either
    a dict ``{name: [phrase, ...]}`` (one column per name, in sorted order) or a
    list of phrases treated as a single topic named ``"topic"``. A paragraph's
    score for a topic is the sum, over the distinct tokens returned by
    ``param_extractor`` for that paragraph, of ``tf(topic, token) * idf(token)``.
    """

    def __init__(
        self,
        topics: Union[Dict[str, List[str]], List[str]],
        *,
        preprocess: bool = True,
    ) -> None:
        """
        Args:
            topics: topic phrases (see class docstring).
            preprocess: if True (and spaCy is available), lemmatise and
                lower-case topic phrases with ``it_core_news_sm``, dropping
                stop words, punctuation and whitespace tokens.
        """
        self.topics = topics
        self.preprocess = bool(preprocess) and spacy is not None

        # Column order of the score matrix.
        self.order_keys = (
            sorted(topics) if isinstance(topics, dict) else ["topic"]
        )

        if self.preprocess:
            self._nlp = spacy.load("it_core_news_sm", disable=["parser", "ner"])
            if isinstance(topics, dict):
                self.topics = {
                    k: [self._pre_text(t) for t in topics[k]] for k in self.order_keys
                }
            else:
                self.topics = [self._pre_text(t) for t in topics]

        self._prepare_corpus()

    @staticmethod
    def _tok(text: str) -> List[str]:
        """Lower-case word tokens (``\\w+`` runs) of ``text``."""
        return re.findall(r"\w+", text.lower())

    def _pre_text(self, text: str) -> str:
        """Space-joined lower-cased lemmas of ``text`` without stop words/punctuation."""
        doc = self._nlp(text)
        return " ".join(
            tok.lemma_.lower()
            for tok in doc
            if not (tok.is_stop or tok.is_punct or tok.is_space)
        )

    def _prepare_corpus(self) -> None:
        """Build per-topic term counts and smoothed IDF weights."""
        self.topic_docs: Dict[str, str] = {}
        self.tf_topic: Dict[str, Counter] = {}
        vocab = set()

        if isinstance(self.topics, dict):
            for cls in self.order_keys:
                doc = " ".join(self.topics[cls])
                self.topic_docs[cls] = doc
                cnt = Counter(self._tok(doc))
                self.tf_topic[cls] = cnt
                vocab.update(cnt)
        else:
            doc = " ".join(self.topics)
            self.topic_docs["topic"] = doc
            cnt = Counter(self._tok(doc))
            self.tf_topic["topic"] = cnt
            vocab.update(cnt)

        N = len(self.topic_docs)
        # Document frequency = number of topics whose *token* counts contain the
        # term. A substring test on the joined text would also count partial
        # matches (e.g. "re" inside "premere") and inflate df.
        df = Counter()
        for cnt in self.tf_topic.values():
            df.update(cnt.keys())
        # Smoothed IDF: log((N + 1) / (df + 1)) + 1, always >= 1.
        self.idf = {tok: math.log((N + 1) / (df_ + 1)) + 1 for tok, df_ in df.items()}
        self._vocab = vocab

    def get_vocab(self) -> set:
        """Set of all tokens appearing in any topic document."""
        return self._vocab

    @staticmethod
    def _softmax(X: np.ndarray, t: float) -> np.ndarray:
        """Row-wise softmax of ``X`` with temperature ``t`` (floored at 1e-8)."""
        X_stab = X - np.max(X, axis=1, keepdims=True)
        eX = np.exp(X_stab / max(t, 1e-8))
        return eX / np.sum(eX, axis=1, keepdims=True)

    def compute_lexical_scores(
        self,
        paragraphs: List[str],
        *,
        verbose: bool = False,
    ) -> np.ndarray:
        """Return a ``(len(paragraphs), n_topics)`` matrix of lexical scores.

        Paragraphs for which ``param_extractor`` finds nothing keep an all-zero row.

        Args:
            paragraphs: paragraph texts.
            verbose: print each paragraph before scoring (debug aid).

        NOTE(review): paragraph tokens come from ``param_extractor`` unlemmatised
        and with internal spaces removed by ``_cleanx`` ("Jacopo Belbo" ->
        "jacopobelbo"), while topic tokens are ``\\w+`` words (lemmatised when
        ``preprocess``); multi-word entities therefore cannot match — confirm.
        """
        m = len(paragraphs)
        k = len(self.order_keys)
        S = np.zeros((m, k), dtype=float)

        for i, para in enumerate(paragraphs):
            if verbose:
                print(para)
            theta = param_extractor(para)
            if not theta:
                continue
            theta_set = {
                tok.lower() for lst in theta.values() for tok in lst if isinstance(tok, str)
            }

            for j, cls in enumerate(self.order_keys):
                tf = self.tf_topic[cls]
                s = 0.0
                for tok in theta_set:
                    tf_tok = tf.get(tok, 0)
                    if tf_tok:
                        s += tf_tok * self.idf.get(tok, 0.0)
                S[i, j] = s

        return S


### Semantic Topic Strength

In [None]:
class SemanticTopicStrength:
    """Score paragraphs against topics by embedding cosine similarity.

    ``topics`` is either a dict ``{name: [phrase, ...]}`` (one score column per
    name, in sorted-name order) or a list of texts treated as a single topic
    (one score column).
    """

    def __init__(self, topics, model_name, device="cuda", batch_size=16):
        # Fall back to CPU when CUDA is requested but unavailable
        # (same policy as FluidDynamicTopicFlow.__init__).
        if device == "cuda" and not torch.cuda.is_available():
            device = "cpu"
        self.topics = topics
        self.model_name = model_name
        self.device = device
        self.model = SentenceTransformer(self.model_name, device=device, trust_remote_code=True)
        self.batch_size = batch_size

        if isinstance(self.topics, dict):
            self.order_keys = sorted(self.topics.keys())
            # One embedding row per phrase, stacked topic by topic, plus the
            # phrase count of each topic so the rows can be regrouped.
            self.topic_embeddings, self.topic_length = self._compute_topic_embeddings(self.batch_size)
        else:
            # A list of texts is one topic: a single averaged embedding vector.
            self.topic_embeddings = self.compute_block_embeddings()
            self.topic_length = 1

    def _encode_batch(self, sentences):
        """Encode a batch of sentences to a numpy array (no autograd)."""
        with torch.no_grad():
            embs = self.model.encode(sentences, convert_to_numpy=True)
        return embs

    def split_into_chunks(self, text, max_chunk_chars=500):
        """Greedily pack sentences (split after '.') into chunks of <= max_chunk_chars.

        A single sentence longer than the limit still becomes its own chunk.
        """
        sentences = re.split(r'(?<=\.)\s+', text.strip())
        chunks = []
        current_chunk = ""

        for sentence in sentences:
            if current_chunk:
                if len(current_chunk) + len(sentence) + 1 <= max_chunk_chars:
                    current_chunk += " " + sentence
                else:
                    chunks.append(current_chunk.strip())
                    current_chunk = sentence
            else:
                current_chunk = sentence

        if current_chunk:
            chunks.append(current_chunk.strip())

        return chunks

    def compute_block_embeddings(self, max_chunk_chars=500):
        """Average embedding of all texts in a list-style ``topics``.

        Each text is chunked and its chunk embeddings averaged; the per-text
        averages are then averaged again.
        """
        all_topic_embeddings = []

        for topic in self.topics:
            chunks = self.split_into_chunks(topic, max_chunk_chars)
            chunk_embeddings = self.model.encode(chunks)
            all_topic_embeddings.append(np.mean(chunk_embeddings, axis=0))
        return np.mean(all_topic_embeddings, axis=0)

    def _compute_topic_embeddings(self, batch_size=16):
        """Embed every phrase of every topic (in ``order_keys`` order).

        Returns:
            (matrix with one row per phrase, list of phrase counts per topic)
        """
        all_embs = []
        topic_length = []
        for tname in self.order_keys:
            phrases = self.topics[tname]
            topic_length.append(len(phrases))
            loader = DataLoader(SentenceDataset(phrases), batch_size=batch_size, shuffle=False)
            for batch in loader:
                all_embs.append(self._encode_batch(batch))
        emb_matrix = np.concatenate(all_embs, axis=0)
        return emb_matrix, topic_length

    def compute_centroids(self, keyword_embedding, topic_length):
        """Mean embedding of each consecutive block of rows of size ``topic_length[i]``."""
        topic_vectors = []
        idx = 0
        for length in topic_length:
            vecs = keyword_embedding[idx : idx + length]
            idx += length
            if vecs.shape[0] == 1:
                topic_vectors.append(vecs[0])
            else:
                topic_vectors.append(np.mean(vecs, axis=0))
        return np.vstack(topic_vectors)

    def compute_paragraph_embeddings(self, paragraphs):
        """Embed ``paragraphs`` with the sentence model as a numpy array."""
        return np.array(self.model.encode(paragraphs))

    def cosine_similarity_matrix(self, A_np, B_np, eps=1e-12):
        """Pairwise cosine similarity between the rows of ``A_np`` and ``B_np``.

        Row norms are floored at ``eps`` so an all-zero row yields similarity 0
        instead of NaN.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        A = torch.tensor(A_np, device=device, dtype=torch.float32)
        B = torch.tensor(B_np, device=device, dtype=torch.float32)
        A = A / A.norm(dim=1, keepdim=True).clamp_min(eps)
        B = B / B.norm(dim=1, keepdim=True).clamp_min(eps)
        sim = torch.mm(A, B.t())
        return sim.cpu().numpy()

    def combine_blocks(self,
            M: np.ndarray,
            block_boundaries: list,
            transform='relu',
            penalty='log'
        ):
        """Collapse column blocks of ``M`` into one weighted mean per block.

        Args:
            M: (rows, sum(block_boundaries)) matrix, e.g. paragraph-phrase similarities.
            block_boundaries: consecutive block widths (phrases per topic).
            transform: 'relu' (clip negatives), 'linear' ((x+1)/2) or 'identity'.
            penalty: weight applied to a block of size K: 'none' (1),
                'inverse' (1/K), 'inverse_sqrt' (1/sqrt K), 'log' (1/log(K+1))
                or 'custom' (1/K**0.75).

        Returns:
            (rows, len(block_boundaries)) array. An empty block (K == 0)
            yields a zero column instead of NaN / ZeroDivisionError.

        Raises:
            ValueError: on an unknown ``transform`` or ``penalty``.
        """
        num_rows = M.shape[0]
        num_blocks = len(block_boundaries)

        if transform == 'relu':
            def f(x): return np.maximum(x, 0.0)
        elif transform == 'linear':
            def f(x): return (x + 1.) / 2.
        elif transform == 'identity':
            def f(x): return x
        else:
            raise ValueError("Invalid transform!")

        out = np.zeros((num_rows, num_blocks))
        idx = 0

        for j, length in enumerate(block_boundaries):
            K_j = length
            start = idx
            end = length + idx
            idx += length

            if penalty == 'none':
                w_j = 1.0
            elif penalty == 'inverse':
                w_j = 1.0 / K_j if K_j else 0.0
            elif penalty == 'inverse_sqrt':
                w_j = 1.0 / np.sqrt(K_j) if K_j else 0.0
            elif penalty == 'log':
                w_j = 1.0 / np.log(K_j + 1.0) if K_j else 0.0
            elif penalty == 'custom':
                alpha = 0.75
                w_j = 1.0 / (K_j ** alpha) if K_j else 0.0
            else:
                raise ValueError("Invalid penalty!")

            if K_j == 0:
                continue  # empty topic: leave the column at zero

            block_mean = f(M[:, start:end]).mean(axis=1)
            out[:, j] = block_mean * w_j

        return out

    def compute_semantic_scores(self, paragraphs):
        """Return a (len(paragraphs), n_topics) matrix of semantic scores.

        Dict topics: per-phrase cosine similarities, averaged per topic with a
        1/log(K+1) size penalty. List topics: cosine similarity to the single
        averaged topic embedding (one column).
        """
        paragraph_embeddings = self.compute_paragraph_embeddings(paragraphs)
        if isinstance(self.topics, dict):
            phrase_sims = self.cosine_similarity_matrix(paragraph_embeddings, self.topic_embeddings)
            return self.combine_blocks(phrase_sims, self.topic_length, transform='identity', penalty='log')
        return self.cosine_similarity_matrix(paragraph_embeddings, self.topic_embeddings.reshape(1, -1))




### ThematicStrength

In [None]:

class ThematicStrength:
    """Blend lexical and semantic topic scores for a fixed set of paragraphs.

    Both score matrices are computed once, in ``__init__``, for ``paragraphs``;
    ``compute_thematic_strength`` only mixes them.
    """

    def __init__(self, topics, model_name, paragraphs, device="cuda", batch_size=16):
        self.topics = topics
        self.model_name = model_name
        self.device = device
        self.lexical_strength = LexicalTopicStrength(topics, preprocess=True)
        self.semantic_strength = SemanticTopicStrength(topics, model_name, device, batch_size=batch_size)
        self.model = self.semantic_strength.model
        self.sentence_embeddings = None
        # (num_paragraphs, num_topics) matrices; columns in sorted topic order
        # (or a single "topic" column for list topics).
        self.lex_dist = self.lexical_strength.compute_lexical_scores(paragraphs)
        self.sem_dist = self.semantic_strength.compute_semantic_scores(paragraphs)

    def _topic_names(self):
        """Column labels of the score matrices: sorted topic names, or ["topic"]."""
        if isinstance(self.topics, dict):
            return sorted(self.topics)
        return ["topic"]

    def compute_thematic_strength(self, paragraphs, return_dict=False, ni=0.5):
        """Return ``(1 - ni) * lexical + ni * semantic`` scores.

        Args:
            paragraphs: unused; kept for API compatibility (scores were
                computed for the paragraphs given to ``__init__``).
            return_dict: if True, return ``{row_index: {topic_name: score}}``.
            ni: weight of the semantic component, in [0, 1].

        NOTE(review): lexical scores are unbounded tf-idf sums while semantic
        scores are cosine-based, so ``ni`` mixes different scales — consider
        normalising each matrix first.
        """
        combined = (1 - ni) * self.lex_dist + ni * self.sem_dist

        if not return_dict:
            return combined

        # Columns follow sorted topic order (order_keys of both scorers), not
        # dict insertion order.
        names = self._topic_names()
        return {
            i: {name: float(row[k]) for k, name in enumerate(names)}
            for i, row in enumerate(combined)
        }



### TopicFlowEquation

In [None]:
def _minmod(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(a)
    mask = (a * b) > 0
    out[mask] = torch.where(torch.abs(a[mask]) < torch.abs(b[mask]), a[mask], b[mask])
    return out



class FluidDynamicTopicFlow:
    """Propagate paragraph-level topic strengths over a paragraph graph.

    Initial strengths (shape ``(num_paragraphs, num_topics)``) come from
    ``ThematicStrength`` with an even lexical/semantic mix (``ni=0.5``).
    Paragraphs are graph nodes. Each paragraph keeps up to ``k`` neighbours
    whose cosine similarity, damped by a Gaussian kernel on paragraph-index
    distance, is at least ``similarity_threshold``. It also keeps a weak link
    to the next paragraph. ``fluid_propagation`` integrates

        dS/dt = advection(S) + diffusion(S) + reaction(S)

    with a Heun-style predictor/corrector step. ``meta`` exposes summary
    statistics over the result.
    """

    def __init__(
        self,
        topics,
        paragraphs,
        *,
        sentence_embeddings: np.ndarray,
        model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
        device: str = "cuda",
        batch_size: int = 16,
        k: Union[int, str] = "auto",
        similarity_threshold: float = 0.8,
        timeral_sigma: float = 100.0,
        directed: bool = True,
        viscosity: float = 0.05,
        viscosity_degree_exponent: float = 1.0,
        edge_velocity_degree_damping: bool = True,
        edge_velocity_scale: float = 1.0,
        lambda_decay: float = 0.0,
        lambda_inject: float = 0.0,
        dt: float = 0.1,
        flux_limiter: bool = True,
        min_iter: int = 20,
        tol: float = 1e-5,
    ) -> None:
        """Compute initial strengths and build the paragraph graph.

        Args:
            topics: topic definitions forwarded to ``ThematicStrength``.
            paragraphs: paragraph texts, one graph node each.
            sentence_embeddings: one embedding row per paragraph (assumed to
                be aligned with ``paragraphs`` — TODO confirm at call site).
            model_name: SentenceTransformer used for the semantic scores.
            device: "cuda" falls back to "cpu" when CUDA is unavailable.
            batch_size: encoding batch size for topic phrases.
            k: neighbours kept per paragraph, or "auto" (see ``_auto_k``).
            similarity_threshold: minimum damped similarity for an edge.
            timeral_sigma: width, in paragraphs, of the Gaussian distance kernel.
            directed: use a DiGraph. For a selected neighbour j > i, i->j gets
                the full weight and j->i 10% of it.
            viscosity, viscosity_degree_exponent: per-node diffusion
                coefficient ``viscosity / (1 + degree) ** exponent``.
            edge_velocity_degree_damping, edge_velocity_scale: shape each
                edge's advection velocity.
            lambda_decay: relaxation rate back towards the initial strengths.
            lambda_inject: coefficient of the ``-lambda_inject * laplacian`` term.
            dt: requested step size (further limited by a CFL-style bound).
            flux_limiter: apply the minmod limiter to advective fluxes.
            min_iter, tol: convergence control for ``fluid_propagation``.
        """
        if device == "cuda" and not torch.cuda.is_available():
            device = "cpu"
        self.device = torch.device(device)

        self.embedder = ThematicStrength(
            topics=topics,
            model_name=model_name,
            paragraphs=paragraphs,
            device=device,
            batch_size=batch_size
        )

        self.topics = topics
        self.paragraphs = paragraphs
        self.embeddings = sentence_embeddings
        # compute_thematic_strength() accepts (paragraphs, return_dict, ni);
        # passing any other keyword raises TypeError.
        self.initial_strength = torch.tensor(
            self.embedder.compute_thematic_strength(paragraphs, ni=0.5),
            dtype=torch.float32,
            device=self.device,
        )

        assert torch.isfinite(self.initial_strength).all(), "initial_strength contains NaN/Inf"
        assert not np.isnan(self.embeddings).any(),        "sentence_embeddings contains NaN"

        # Graph construction parameters.
        self.k = k
        self.similarity_threshold = similarity_threshold
        self.timeral_sigma = timeral_sigma
        self.directed = directed

        # Diffusion / advection parameters.
        self.viscosity = viscosity
        self.vdeg_exp = viscosity_degree_exponent
        self.edge_velocity_degree_damping = edge_velocity_degree_damping
        self.edge_velocity_scale = edge_velocity_scale

        # Reaction terms.
        self.lambda_decay = lambda_decay
        self.lambda_inject = lambda_inject

        # Integration control.
        self.dt_requested = dt
        self.flux_limiter = flux_limiter
        self.min_iter = min_iter
        self.tol = tol

        self.num_paragraphs, self.num_topics = self.initial_strength.shape
        self.graph = self._build_graph()

        self.evolved_strength = self.initial_strength.clone()
        # History on CPU. _history_strength[0] is the initial state; the flux
        # histories get one entry per iteration, so their index t corresponds
        # to _history_strength[t + 1].
        self._history_strength: List[torch.Tensor] = [self.initial_strength.detach().cpu()]
        self._history_adv_flux: List[torch.Tensor] = []
        self._history_diff_flux: List[torch.Tensor] = []

        self.meta = self.MetaNarrative(self)

    def _timeral_kernel(self, i: int, j: int) -> float:
        """Gaussian weight exp(-(j - i)^2 / (2 sigma^2)) on paragraph-index distance."""
        return math.exp(-((j - i) ** 2) / (2.0 * self.timeral_sigma ** 2))

    def _auto_k(self) -> int:
        """round(sqrt(num_paragraphs)), clamped to [4, 50]."""
        return max(4, min(50, int(round(math.sqrt(self.num_paragraphs)))))

    def _build_graph(self) -> nx.DiGraph | nx.Graph:
        """Build the paragraph graph (see class docstring and ``__init__`` args)."""
        G = nx.DiGraph() if self.directed else nx.Graph()
        k_val = self._auto_k() if (isinstance(self.k, str) and self.k == "auto") else int(self.k)

        sim_matrix = cosine_similarity(self.embeddings)
        for i in range(self.num_paragraphs):
            G.add_node(i)
            weights = sim_matrix[i] * np.array([self._timeral_kernel(i, j) for j in range(self.num_paragraphs)])
            weights[i] = 0.0  # no self-loops
            for j in weights.argsort()[::-1][:k_val]:
                w = float(weights[j])
                if w < self.similarity_threshold:
                    continue
                if self.directed and j > i:
                    G.add_edge(i, j, weight=w)
                    G.add_edge(j, i, weight=w * 0.1)
                else:
                    G.add_edge(i, j, weight=w)
            # Weak baseline link to the next paragraph. Skip it when a (stronger)
            # similarity edge already exists: add_edge() would overwrite that
            # edge's weight with 0.05.
            if i < self.num_paragraphs - 1 and not G.has_edge(i, i + 1):
                G.add_edge(i, i + 1, weight=0.05)
        return G

    def _rhs(self, S: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate dS/dt at state ``S``.

        Returns:
            (total rate, advective part, diffusive part), each shaped like ``S``.

        Raises:
            RuntimeError: if ``S`` contains NaN/Inf.
        """
        if not torch.isfinite(S).all():
            raise RuntimeError("NaN/Inf in S all'ingresso di _rhs")

        degrees = torch.tensor(
            [self.graph.degree(v) for v in range(self.num_paragraphs)],
            dtype=torch.float32,
            device=self.device,
        )
        # Per-node diffusion coefficient: high-degree nodes diffuse less.
        nu_node = self.viscosity / (1.0 + degrees) ** self.vdeg_exp

        adv_flux = torch.zeros_like(S)
        diff_flux = torch.zeros_like(S)

        # Edge advection velocity: weight * cosine similarity * scale,
        # optionally damped by endpoint degree, clipped to [-VEL_MAX, VEL_MAX].
        VEL_MAX = 2.0
        edge_vel: Dict[Tuple[int, int], float] = {}

        def _safe_cos(u: np.ndarray, v: np.ndarray, eps: float = 1e-12) -> float:
            num = float(np.dot(u, v))
            den = float(np.linalg.norm(u) * np.linalg.norm(v))
            return 0.0 if den < eps else num / den

        for i, j, w in self.graph.edges.data("weight"):
            cos_sim = _safe_cos(self.embeddings[i], self.embeddings[j])
            v_abs = w * cos_sim * self.edge_velocity_scale
            if self.edge_velocity_degree_damping:
                v_abs /= 1.0 + max(self.graph.degree(i), self.graph.degree(j))
            if not math.isfinite(v_abs):
                v_abs = 0.0

            v_abs = max(-VEL_MAX, min(VEL_MAX, v_abs))
            edge_vel[(i, j)] = v_abs

        # Advection: each edge i->j moves v * (S[i] - S[j]) from i to j.
        node_flux = torch.zeros_like(S)
        for i, j in self.graph.edges:
            v = edge_vel[(i, j)]
            if v == 0.0:
                continue
            delta = S[i] - S[j]
            flux = v * delta
            if self.flux_limiter:
                # NOTE(review): the limiter compares against S[j] - S[i + 1]
                # (next paragraph by index), not a downstream graph
                # neighbour of j — confirm this is intended.
                delta_next = (S[j] - S[i + 1]) if (i + 1) < self.num_paragraphs else torch.zeros_like(delta)
                flux = _minmod(flux, v * delta_next)
            node_flux[i] -= flux
            node_flux[j] += flux
        adv_flux.copy_(node_flux)

        # Diffusion: graph Laplacian (mean of out-neighbours minus self).
        lap = torch.zeros_like(S)
        for i in range(self.num_paragraphs):
            neigh = list(self.graph.neighbors(i))
            if not neigh:
                continue
            neigh_tensor = torch.stack([S[j] for j in neigh])
            lap_i = torch.mean(neigh_tensor, dim=0) - S[i]
            lap[i] = lap_i
            diff_flux[i] = nu_node[i] * lap_i
        diff_term = diff_flux.clone()

        # Reaction: relaxation towards the initial strengths, minus an
        # injection term proportional to the Laplacian.
        reaction = self.lambda_decay * (self.initial_strength - S) - self.lambda_inject * lap

        F = adv_flux + diff_term + reaction
        return F, adv_flux, diff_term

    def fluid_propagation(self, *, max_iter: int = 400,
                          clip: Tuple[float, float] | None = (0.0, 1.0)) -> torch.Tensor:
        """Integrate the flow from ``evolved_strength`` until convergence.

        Each iteration is a predictor/corrector (Heun-style) step with
        ``dt = min(dt_requested, 0.45 / max|dS/dt|)``. Integration stops after
        ``max_iter`` steps, or once at least ``min_iter`` steps are done and
        the median relative change falls below ``tol``.

        Args:
            max_iter: maximum number of steps.
            clip: (low, high). Values are hard-clamped below ``low``; when
                ``high`` is not None they are softly saturated near it with
                tanh. Pass None to disable clipping.

        Returns:
            The final strengths, also stored in ``evolved_strength``.
            Every accepted step is appended to the history lists.

        Raises:
            RuntimeError: if a step produces NaN/Inf.
        """

        def _soft_upper(x: torch.Tensor, high: float, delta: float = 0.15) -> torch.Tensor:
            # Identity below high - delta; above it, tanh saturation towards high.
            return torch.where(
                x <= high - delta,
                x,
                high - delta + delta * torch.tanh((x - (high - delta)) / delta),
            )

        S = self.evolved_strength.clone()
        diff_ratio = float("inf")
        n_done = 0
        CFL_SAFETY = 0.45

        for it in tqdm(range(max_iter)):

            F1, adv1, diff1 = self._rhs(S)
            vmax = float(torch.max(torch.abs(F1))) + 1e-12
            dt = min(self.dt_requested, CFL_SAFETY / vmax)

            # Predictor (explicit Euler).
            S_pred = S + dt * F1
            if clip:
                S_pred = torch.clamp_min(S_pred, clip[0])
                if clip[1] is not None:
                    S_pred = _soft_upper(S_pred, clip[1])

            # Corrector (trapezoidal average of both slopes).
            F2, adv2, diff2 = self._rhs(S_pred)
            S_new = S + 0.5 * dt * (F1 + F2)
            if not torch.isfinite(S_new).all():
                raise RuntimeError(
                    f"NaN/Inf detected at iteration {it + 1}. "
                    "Controlla stabilità numerica e parametri."
                )
            if clip:
                S_new = torch.clamp_min(S_new, clip[0])
                if clip[1] is not None:
                    S_new = _soft_upper(S_new, clip[1])

            self._history_strength.append(S_new.detach().cpu())
            self._history_adv_flux.append(((adv1 + adv2) * 0.5).detach().cpu())
            self._history_diff_flux.append(((diff1 + diff2) * 0.5).detach().cpu())

            diff_ratio = torch.median(torch.abs(S_new - S) / (torch.abs(S) + 1e-9)).item()
            print(f"[iter {it + 1:03d}]  median ΔS/S = {diff_ratio:.3e}")
            # Accept the step before testing convergence, so the returned state
            # matches the last entry in _history_strength.
            S = S_new
            n_done = it + 1
            if n_done >= self.min_iter and diff_ratio < self.tol:
                break

        self.evolved_strength = S.detach()
        print(f"Iterations: {n_done} — median ΔS/S = {diff_ratio:.2e}")
        return S

    class MetaNarrative:
        """Per-paragraph summary statistics over a FluidDynamicTopicFlow."""

        def __init__(self, outer: "FluidDynamicTopicFlow") -> None:
            self._o = outer

        def _mat(self, evolved: bool, step: Optional[int]):
            """Strength matrix at history ``step``, else evolved or initial."""
            if step is not None:
                return self._o._history_strength[step].to(self._o.device)
            return self._o.evolved_strength if evolved else self._o.initial_strength

        def intensity(self, evolved: bool = True, step: Optional[int] = None):
            """Total strength per paragraph (row sums)."""
            return torch.sum(self._mat(evolved, step), dim=1).cpu().numpy()

        def entropy(self, evolved: bool = True, step: Optional[int] = None, base: int = 2):
            """Shannon entropy of each paragraph's row-normalised topic strengths."""
            S = self._mat(evolved, step)
            P = S / (torch.sum(S, dim=1, keepdim=True) + 1e-9)
            return entropy(P.cpu().numpy(), axis=1, base=base)

        def divergence(self, step: Optional[int] = None):
            """Row sums of the averaged advective flux (last iteration by default).

            Requires at least one ``fluid_propagation`` step.
            """
            sel = self._o._history_adv_flux[-1] if step is None else self._o._history_adv_flux[step]
            return torch.sum(sel, dim=1).numpy()

        def shear_index(self, step: Optional[int] = None):
            """Share of |advective| in |advective| + |diffusive| flux per paragraph."""
            adv = self._o._history_adv_flux[-1] if step is None else self._o._history_adv_flux[step]
            diff = self._o._history_diff_flux[-1] if step is None else self._o._history_diff_flux[step]
            num = torch.sum(torch.abs(adv), dim=1)
            den = num + torch.sum(torch.abs(diff), dim=1) + 1e-9
            return (num / den).numpy()

        def latent_modes(self, n_components: int = 5):
            """PCA projection of (initial - evolved); zeros if nothing changed."""
            resid = (self._o.initial_strength - self._o.evolved_strength).cpu().numpy()
            if np.max(np.abs(resid)) < 1e-8:
                return np.zeros((self._o.num_paragraphs, n_components))
            pca = PCA(n_components=min(n_components, resid.shape[1]))
            return pca.fit_transform(resid)

        def _rank(self, vec: np.ndarray, top_k: int):
            """Top-``top_k`` (index, value) pairs of ``vec`` in descending order."""
            idx = np.argsort(vec)[::-1][:top_k]
            return [(int(i), float(vec[i])) for i in idx]

        def summary_report(self, top_k: int = 5):
            """Top paragraphs by intensity, entropy, |divergence| and shear."""
            return {
                "top_intense": self._rank(self.intensity(), top_k),
                "most_diverse": self._rank(self.entropy(), top_k),
                "divergence": self._rank(np.abs(self.divergence()), top_k),
                "shear": self._rank(self.shear_index(), top_k),
            }




## Run model

In [None]:
# Flatten the per-chapter paragraph lists into one ordered list of paragraphs.
text_segments = list(itertools.chain.from_iterable(chapter_paragraphs.values()))
print(len(text_segments))
doc = ' '.join(text_segments)
# ``paragraphs`` aliases the same list object; it is also written to CSV.
paragraphs = text_segments
save_paragraphs(paragraphs)


In [None]:
device = 'cuda'
model_safe = model_name.split('/')[-1]
# Must match the file written by the "Save embeddings" cell
# (sentence_embeddings_<model_safe>.pt); with a different name the cache
# is never found and embeddings are recomputed on every run.
embeddings_path = f'sentence_embeddings_{model_safe}.pt'
if not os.path.isfile(embeddings_path):
    dataset = SentenceDataset(paragraphs)

    model = load_embedding_model(model_name, device=device)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    sentence_embeddings = []
    for batch in tqdm(loader, desc='Computing sentence embeddings..'):
        with torch.no_grad():
            emb = model.encode(batch, convert_to_numpy=True)
            sentence_embeddings.append(emb)
    sentence_embeddings = np.concatenate(sentence_embeddings, axis=0)
else:
    # The cached file is produced locally by this notebook (trusted input) and
    # holds a pickled NumPy array, which torch.load rejects under its
    # weights_only=True default, so full unpickling is requested explicitly.
    sentence_embeddings = torch.load(embeddings_path, weights_only=False)
    # Keep the same type as the freshly computed branch (NumPy array).
    if isinstance(sentence_embeddings, torch.Tensor):
        sentence_embeddings = sentence_embeddings.cpu().numpy()


In [None]:

# Save embeddings under the model's short name (last path component), the
# same name the embedding cell checks before recomputing. ``[-1]`` also
# works for model names that contain no '/'.
model_safe = model_name.split('/')[-1]
save_paragraph_embeddings(sentence_embeddings, path=f'sentence_embeddings_{model_safe}.pt')

In [None]:
# ---------------------------------------------------------------------------
# Fluid dynamics parameters
# ---------------------------------------------------------------------------
# Graph construction
k_neighbors = 25             # nearest neighbours per paragraph
similarity_threshold = 0.65  # minimum similarity for an edge

# Dynamics
lambda_inject = 0.03         # injection rate
viscosity = 0.015            # fluid viscosity coefficient
edge_velocity_scale = 10     # scaling factor for edge velocities
dt = 0.05                    # time step of the numerical integration

# Iteration control: convergence (median dS/S < tolerance) is only checked
# once min_iterations have run; max_iterations is the hard upper bound.
max_iterations = 1000
min_iterations = 50
tolerance = 1e-6

flow = FluidDynamicTopicFlow(
    topics,
    paragraphs,
    sentence_embeddings=sentence_embeddings,
    # graph
    k=k_neighbors,
    similarity_threshold=similarity_threshold,
    # dynamics
    lambda_inject=lambda_inject,
    viscosity=viscosity,
    edge_velocity_scale=edge_velocity_scale,
    dt=dt,
    # convergence
    min_iter=min_iterations,
    tol=tolerance,
)

# Propagate until convergence (or max_iterations) and keep the final strengths.
final_strength = flow.fluid_propagation(max_iter=max_iterations)


## Analysis of the output

In [None]:
# Compute the report once and reuse it; summary_report() defaults to
# top_k=5, so this prints exactly what a separate call would.
report = flow.meta.summary_report(top_k=5)
print(report)

print("\n=== SUMMARY REPORT ===")
for key, items in report.items():
    print(f"{key}: {items}")

latent = flow.meta.latent_modes(n_components=3)
print("\nLatent modes (first 3 paragraphs):\n", latent[:3])

# Print each highlighted paragraph once, even when it appears in several
# rankings; ``visited`` ends up holding the unique ids in first-seen order.
visited = []
for key, items in report.items():
    for item in items:
        if item[0] not in visited:
            print(f"{item[0]}: {paragraphs[item[0]]}")
            visited.append(item[0])



### Heat map of the steady-state profile

In [None]:
def save_thematic_heatmap_data(
    file_path: str,
    *,
    text_segments: Sequence[str],
    strength_matrix: np.ndarray,
    topic_labels: Sequence[str],
    chapter_lengths: Sequence[int],
    meta: Mapping[str, Union[int, float, bool, str]] | None = None,
) -> None:

    topic_dict = {
        topic_labels[k]: [
            (text_segments[i], float(strength_matrix[i, k]))
            for i in range(len(text_segments))
        ]
        for k in range(len(topic_labels))
    }

    payload = {
        "topics": topic_dict
    }

    with open(file_path, "w", encoding="utf-8") as fp:
        json.dump(payload, fp, ensure_ascii=False, indent=2)

        fp.flush()
        os.fsync(fp.fileno())
    return topic_dict


def plot_thematic_heatmap(
    *,
    text_segments: Sequence[str],
    strength_matrix: np.ndarray,
    topics: Union[Mapping[str, str], Sequence[str]],
    chapter_lengths: Union[Mapping[int, int], Sequence[int]],
    window: int = 9,
    poly: int = 3,
    reorder_by_activity: bool = True,
    strength_cutoff: float = 0.0,
    figsize: tuple[int, int] = (26, 10),
    cmap_name: str = "viridis",
    colorbar_label: str = r"$S^\ast$",
    chapter_line_lw: float = 2.5,
    x_tick_step: int = 15,
    y_label_size: int = 40,
    x_label_size: int = 40,
    cb_label_size: int = 40,
    chapter_label_size: int = 29,
    data_path: str = "heatmap_data.json",
) -> dict[str, list[tuple[str, float]]]:
    """Plot a smoothed paragraph x topic heat map and export its data.

    Each topic column is smoothed along the paragraph axis with a
    Savitzky-Golay filter. Topics whose smoothed maximum does not exceed
    ``strength_cutoff`` are dropped. With ``reorder_by_activity`` the rest are
    sorted by peak strength, strongest first (row 0 is drawn at the bottom
    because ``origin="lower"``). Chapter boundaries are drawn as dashed
    vertical lines and each chapter is labelled above the axes.

    Args:
        text_segments: One text per row of ``strength_matrix``.
        strength_matrix: 2-D array indexed ``[paragraph, topic]``.
        topics: Topic labels (mapping keys or a sequence), one per column.
        chapter_lengths: Paragraphs per chapter (mapping values or a sequence,
            in plotting order); may cover more than the plotted paragraphs.
        window, poly: Savitzky-Golay window length and polynomial order;
            adjusted to an odd window that fits the data.
        data_path: JSON file receiving the smoothed, filtered matrix.
        The remaining keywords only control figure styling.

    Returns:
        The per-topic mapping returned by ``save_thematic_heatmap_data``.

    Raises:
        AssertionError: if labels, texts or chapter lengths do not match the
            matrix dimensions.
        ValueError: if the matrix is empty or no topic survives the cutoff.
    """
    n_par, n_top = strength_matrix.shape
    if n_par == 0 or n_top == 0:
        raise ValueError("strength_matrix must have at least one row and one column")
    topic_labels = list(topics.keys()) if isinstance(topics, Mapping) else list(topics)
    assert len(topic_labels) == n_top, "need one topic label per strength column"
    assert len(text_segments) == n_par, "need one text segment per strength row"
    chap_lengths = (list(chapter_lengths.values())
                    if isinstance(chapter_lengths, Mapping) else list(chapter_lengths))
    assert sum(chap_lengths) >= n_par, "chapter lengths must cover every paragraph"

    # Savitzky-Golay needs an odd window no longer than the series and a
    # polynomial order strictly below the window length.
    window = max(3, window + (window % 2 == 0))
    window = min(window, n_par if n_par % 2 else n_par - 1)
    poly = min(poly, window - 1)
    smooth_mat = savgol_filter(strength_matrix, window, poly, axis=0, mode="interp")

    keep = smooth_mat.max(axis=0) > strength_cutoff
    if not keep.any():
        raise ValueError(
            f"no topic exceeds strength_cutoff={strength_cutoff}; nothing to plot"
        )
    smooth_mat = smooth_mat[:, keep]
    topic_labels = [lbl for lbl, k in zip(topic_labels, keep) if k]

    if reorder_by_activity:
        order = np.argsort(smooth_mat.max(axis=0))[::-1]
        smooth_mat = smooth_mat[:, order]
        topic_labels = [topic_labels[i] for i in order]

    n_top_kept = smooth_mat.shape[1]

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(smooth_mat.T, aspect="auto",
                   cmap=cmap_name, interpolation="nearest", origin="lower")

    ax.set_xlabel("Paragraph index", fontsize=x_label_size+4, labelpad=12)
    ax.set_ylabel("Top-level concept", fontsize=x_label_size+4, labelpad=12)
    ax.set_yticks(np.arange(n_top_kept))
    ax.set_yticklabels(topic_labels, fontsize=y_label_size)
    ax.tick_params(axis="x", labelsize=y_label_size)
    ax.xaxis.set_major_locator(MultipleLocator(x_tick_step))

    # Chapter separators. Boundaries at or beyond the last plotted paragraph
    # are skipped: axvline outside the image would widen the x-axis.
    boundaries = np.cumsum(chap_lengths)
    for b in boundaries[:-1]:
        if b < n_par:
            ax.axvline(b - 0.5, color="white", linestyle="--", linewidth=chapter_line_lw)

    # Chapter labels centred over each chapter in axes-fraction coordinates;
    # a chapter overrunning the data is clipped to the plotted range.
    start = 0
    total = n_par
    for idx, length in enumerate(chap_lengths, start=1):
        end = min(start + length, total)
        xpos_frac = (start + end) / 2 / total
        ax.text(xpos_frac, 1.02, f"Chapter {idx}",
                transform=ax.transAxes,
                ha="center", va="bottom",
                fontsize=chapter_label_size,
                color="black", weight="bold")
        start = end
        if start >= total:
            break

    cbar = fig.colorbar(im, ax=ax, pad=0.02)
    cbar.ax.set_ylabel(colorbar_label, fontsize=cb_label_size)
    cbar.ax.tick_params(labelsize=y_label_size)

    payload = save_thematic_heatmap_data(
        file_path=data_path,
        text_segments=text_segments,
        strength_matrix=smooth_mat,
        topic_labels=topic_labels,
        chapter_lengths=chap_lengths,
        meta={
            "window": window,
            "poly": poly,
            "strength_cutoff": strength_cutoff,
            "reorder_by_activity": reorder_by_activity,
            "cmap": cmap_name
        }
    )
    plt.tight_layout()
    plt.show()
    return payload

# Smooth, filter and draw the converged strengths. plot_thematic_heatmap
# also writes the plotted data to JSON and returns the per-topic mapping.
payload = plot_thematic_heatmap(
    text_segments=paragraphs,
    topics=topics,
    strength_matrix=final_strength.cpu().numpy(),
    chapter_lengths=chapter_length,
    # smoothing / filtering
    window=9,
    poly=3,
    strength_cutoff=0.02,
    # styling
    cmap_name="viridis_r",
    chapter_line_lw=8.0,
    x_tick_step=15,
)




### Authorial filtering visualization

In [None]:

# --- Paragraph selection --------------------------------------------------
K_SIGMA_DIV = 1.0        # not referenced in this cell
# A low-|divergence| paragraph is a "mixer" when its shear index exceeds
# this quantile of all shear values.
SHEAR_QUANTILE = 0.90
# Quantile of the non-zero |divergence| values used as the source/sink cut.
DIVERG_QUANTILE = 0.75
# Minimum number of key paragraphs, overall and per chapter.
MIN_PARAGRAPHS = 10

# --- Token and edge filtering ----------------------------------------------
DF_THRESHOLD = 2         # forwarded to param_extractor as df_threshold
# Minimum instance/category co-occurrence count for a "cross" link;
# with 0 every co-occurring pair qualifies.
CROSS_DF = 0
# An instance that never reaches a sink gets an "∞ inst" link if it occurs
# in at least REC_DF key paragraphs (or matches an alias); 10_000
# effectively leaves only the alias route.
REC_DF = 10_000
LINK_ALPHA_SRC = 0.50    # opacity of paragraph -> instance links
LINK_ALPHA_OTH = 0.35    # opacity of the other links (scaled down for ∞ links)

# When True, a category whose last tagged key paragraph is a sink gets no
# "∞ cat" link.
TOKEN_KEEP_ALIVE = True

# --- Styling --------------------------------------------------------------
LABEL_FONT_SIZE = 44
LEGEND_FONT_SIZE = 18    # not referenced in this cell


def _norm(t: str) -> str:
    return re.sub(r"[^\w\s]", "", unicodedata.normalize("NFKD", t)
                  .encode("ascii", "ignore").decode()).lower()


_RE_META = re.compile(r"^(?:filename|chapter|page):\s+\S+", re.I)

def _clean_paragraph(raw: str) -> str:
    lines = [ln for ln in raw.splitlines() if not _RE_META.match(ln)]
    return unicodedata.normalize("NFKC", "\n".join(lines).strip())

topics = flow.topics
# alias2cat: normalised alias -> category; alias_keys: the distinct aliases
# (for fuzzy matching); rgx_alias: whole-word regex per alias.
alias2cat, alias_keys, rgx_alias = {}, [], []
for cat, alias in topics.items():
    # A bare string is one alias; star-unpacking it would turn every
    # character into its own one-letter alias.
    aliases = [alias] if isinstance(alias, str) else list(alias)
    # dict.fromkeys dedups while keeping a deterministic order (a set of str
    # iterates in hash order, which changes between interpreter runs).
    for a in dict.fromkeys([cat, *aliases]):
        n = _norm(a)
        if not n:
            # Text that normalises to "" (punctuation, non-Latin script) would
            # compile to r"\b\b" and tag every paragraph with this category.
            continue
        if n not in alias2cat:
            alias_keys.append(n)
            rgx_alias.append((re.compile(r"\b" + re.escape(n) + r"\b", re.I), n))
        # As before, a later category claiming the same alias wins.
        alias2cat[n] = cat


def _alias(tok):
    """Return the normalised alias key matching ``tok``, or None.

    An exact normalised match wins; otherwise the closest alias with a
    difflib similarity ratio of at least 0.85 is returned.
    """
    n = _norm(tok)
    if not n:
        return None
    if n in alias2cat:
        return n
    m = difflib.get_close_matches(n, alias_keys, n=1, cutoff=.85)
    return m[0] if m else None

# Per-paragraph diagnostics from the latest stored flux snapshot.
div_raw = flow.meta.divergence()
shr     = flow.meta.shear_index()
# Normalise divergence by its largest magnitude (into [-1, 1]); "or 1.0"
# avoids dividing by zero when every divergence is zero.
scale   = np.max(np.abs(div_raw)) or 1.0
div     = div_raw / scale

# Thresholds: a shear quantile, and a quantile of the non-zero |div| values
# (0 when all divergences are ~0).
shear_thr = np.quantile(shr, SHEAR_QUANTILE)
abs_div   = np.abs(div)
nz_div    = abs_div[abs_div > 1e-9]
tau_div   = np.quantile(nz_div, DIVERG_QUANTILE) if nz_div.size else 0.
eps       = 1e-6

# Paragraph roles (by index):
#   mixer - |div| at most tau_div (+eps) and shear above shear_thr
#   src   - div below -tau_div
#   sink  - div above +tau_div
mixer = {i for i in range(len(div)) if abs_div[i] <= tau_div + eps and shr[i] > shear_thr}
src   = {i for i in range(len(div)) if div[i] < -tau_div}
sink  = {i for i in range(len(div)) if div[i] >  tau_div}

key_ids = set(src | sink | mixer)
# Too few key paragraphs: add the MIN_PARAGRAPHS lowest-|div| paragraphs.
# NOTE(review): this adds MIN_PARAGRAPHS ids rather than only the deficit,
# and they are NumPy integers while the role sets hold Python ints.
if len(key_ids) < MIN_PARAGRAPHS:
    extra = sorted(np.argsort(abs_div)[:MIN_PARAGRAPHS])
    key_ids.update(extra)

# Map each chapter to the contiguous paragraph-id range it covers.
# NOTE(review): chapters are walked in sorted-key order, while the paragraph
# list was built in chapter_paragraphs insertion order; the two agree only
# if the keys sort in reading order (e.g. ints, not "1", "10", "2").
chapter2pids = {}
start = 0
for chap, length in sorted(chapter_length.items()):
    chapter2pids[chap] = list(range(start, start + length))
    start += length

# Guarantee MIN_PARAGRAPHS key paragraphs per chapter, topping up with the
# chapter's lowest-|div| paragraphs that are not selected yet.
for chap, pids in chapter2pids.items():
    selected = key_ids.intersection(pids)
    deficit  = MIN_PARAGRAPHS - len(selected)
    if deficit > 0:
        remaining = [pid for pid in pids if pid not in key_ids]
        remaining_sorted = sorted(remaining, key=lambda x: abs_div[x])
        key_ids.update(remaining_sorted[:deficit])

# From here on key_ids is an ascending list; later cells rely on that order.
key_ids = sorted(key_ids)

paragraphs = flow.paragraphs

# Node colours for the extracted parameter classes.
PARAM_COL = {"time":"#b39ddb", "space":"#9575cd",
             "character":"#7e57c2", "language":"#5e35b1"}

# inst_map[pid]: instance strings found in key paragraph pid
# class_map[inst]: parameter class of an instance (the last one seen wins)
# cat_map[pid]: categories tagged in key paragraph pid
inst_map, class_map = defaultdict(set), {}
cat_map              = defaultdict(set)

for pid in key_ids:
    raw_par       = paragraphs[pid]
    clean_par     = _clean_paragraph(raw_par)
    info          = param_extractor(clean_par, df_threshold=DF_THRESHOLD)

    # NOTE(review): "language" has a colour in PARAM_COL and is counted in
    # all_params, but no "language" entries are read here -- confirm intent.
    for lbl, key in [("time", "time"), ("space", "space"),
                     ("character", "character")]:
        for tok in info.get(lbl, []):
            inst = tok.strip()
            if not inst:
                # A blank token would become an unlabeled Sankey node and
                # could match an alias that normalises to "".
                continue
            inst_map[pid].add(inst)
            class_map[inst] = key

    # Categories come from two routes: extracted instances that resolve to
    # an alias, and whole-word alias matches in the normalised paragraph.
    txt_norm = _norm(clean_par)
    for inst in inst_map[pid]:
        if (al := _alias(inst)):
            cat_map[pid].add(alias2cat[al])

    for rgx, al in rgx_alias:
        if rgx.search(txt_norm):
            cat_map[pid].add(alias2cat[al])

all_params = {"time", "space", "character", "language"}
all_cats   = set(topics)


def sat(pid):
    """Coverage reached by the key paragraphs up to and including ``pid``.

    Returns ``(param_share, cat_share)``: the fraction of ``all_params``
    classes and of ``all_cats`` categories seen in those paragraphs.
    """
    subset = [p for p in key_ids if p <= pid]
    par = set().union(*({class_map[i] for i in inst_map[p]} for p in subset))
    cat = set().union(*(cat_map[p] for p in subset))
    return len(par) / len(all_params), len(cat) / len(all_cats)


# key_ids is sorted and unique, so sat(p) covers exactly the prefix ending at
# p. Running unions give every prefix in one pass instead of calling sat()
# twice per paragraph (which re-unioned each prefix from scratch).
satP, satC = [], []
_seen_params, _seen_cats = set(), set()
for _p in key_ids:
    _seen_params.update(class_map[i] for i in inst_map[_p])
    _seen_cats.update(cat_map[_p])
    satP.append(len(_seen_params) / len(all_params))
    satC.append(len(_seen_cats) / len(all_cats))


def _rgba(hex_color: str, alpha: float) -> str:
    hex_color = hex_color.lstrip("#")
    r, g, b = [int(hex_color[i:i+2], 16) for i in (0, 2, 4)]
    return f"rgba({r},{g},{b},{alpha})"

def _chapter_of(pid: int) -> int:
    """Return the chapter whose paragraph range contains ``pid``.

    Chapters are taken in sorted-key order from ``chapter_length`` and their
    lengths accumulated into range ends; an id past every range maps to the
    largest chapter key.
    """
    ordered = sorted(chapter_length.items())
    range_ends = itertools.accumulate(length for _, length in ordered)
    for (chap, _), end in zip(ordered, range_ends):
        if pid < end:
            return chap
    return max(chapter_length)


labels, colors, idx = [], [], {}
def _add(lbl, col):
    if lbl not in idx:
        idx[lbl] = len(labels)
        labels.append(lbl)
        colors.append(col)
    return idx[lbl]

# Paragraph nodes: green = source, red = sink, orange = any other key paragraph.
# Labels look like "P₁₂ (c.3)": paragraph id in subscript digits plus chapter.
pid2idx = {}
SUB = str.maketrans("0123456789", "₀₁₂₃₄₅₆₇₈₉")
for pid in key_ids:
    if pid in src:
        col = "#2ca02c"
    elif pid in sink:
        col = "#d62728"
    else:
        col = "#ff7f0e"
    chap  = _chapter_of(pid)
    lbl   = f"P{str(pid).translate(SUB)} (c.{chap})"
    pid2idx[pid] = _add(lbl, col)

# Instance nodes (coloured by parameter class), then one node per category.
# NOTE: _add dedups by label, so an instance string equal to a category name
# shares that category's node.
for inst in sorted(class_map):
    _add(inst, PARAM_COL[class_map[inst]])
for cat in sorted(all_cats):
    _add(cat, "#1f77b4")

# Terminal nodes for instances / categories whose links do not end at a sink.
infinity_inst = _add("∞ inst", "#7f7f7f")
infinity_cat  = _add("∞ cat",  "#b0b0b0")

# inst_cat_df[inst][cat]: number of key paragraphs in which instance ``inst``
# was extracted and category ``cat`` was tagged.
inst_cat_df = defaultdict(lambda: defaultdict(int))
for pid in key_ids:
    for inst, cat in itertools.product(inst_map[pid], cat_map[pid]):
        inst_cat_df[inst][cat] += 1


edge_w, edge_h, edge_c = defaultdict(float), {}, {}
def _edge(s, t, w, h, c):
    edge_w[(s, t)] += w
    edge_h[(s, t)]  = h
    edge_c[(s, t)]  = c

def _pclass(p):
    return ("SRC" if p in src else "SNK" if p in sink
            else "MXR" if p in mixer else "OTH")


# Paragraph -> instance links, tinted with the paragraph node's colour.
for pid in key_ids:
    col_src = colors[pid2idx[pid]]
    par_node = pid2idx[pid]
    link_rgba = _rgba(col_src, LINK_ALPHA_SRC)
    for inst in inst_map[pid]:
        _edge(par_node, idx[inst], 1, f"P{pid} → {inst}", link_rgba)

# Instance -> category links, plus "∞ inst" links for instances that never
# reach a sink paragraph.
for inst in class_map:
    ic         = PARAM_COL[class_map[inst]]
    # Key paragraphs mentioning this instance (ascending: key_ids is sorted).
    pids       = [p for p in key_ids if inst in inst_map[p]]
    alias_key  = _alias(inst)
    first_snk  = min((p for p in pids if p in sink), default=None)

    for pid in pids:
        # Only paragraphs up to the instance's first sink contribute links.
        # Compare with None explicitly: paragraph id 0 is a valid sink but
        # falsy, and a truthiness test would never stop at it.
        if first_snk is not None and pid > first_snk:
            break
        for cat in cat_map[pid]:
            # Link when the instance is an alias of this very category, or
            # the pair co-occurs in at least CROSS_DF key paragraphs.
            direct = (alias_key is not None and alias2cat[alias_key] == cat)
            cross  = inst_cat_df[inst][cat] >= CROSS_DF
            if direct or cross:
                _edge(idx[inst], idx[cat], 1,
                      f"{inst} → {cat}", _rgba(ic, LINK_ALPHA_OTH))

    if first_snk is None and (len(pids) >= REC_DF or alias_key is not None):
        _edge(idx[inst], infinity_inst, 1,
              f"{inst} → ∞", _rgba(ic, LINK_ALPHA_OTH * 0.7))

# Category -> "∞ cat" links for categories whose last tagged key paragraph
# (with extracted instances, and with TOKEN_KEEP_ALIVE also without) is not a
# sink. The link weight is the square root of the number of tagged paragraphs.
for cat in all_cats:
    tagged = [p for p in key_ids if cat in cat_map[p]]
    with_inst = [p for p in tagged if inst_map[p]]
    if not with_inst:
        continue
    last_arc = with_inst[-1]
    if _pclass(last_arc) == "SNK":
        continue
    if TOKEN_KEEP_ALIVE:
        last_tok = tagged[-1]
        if _pclass(last_tok) == "SNK":
            continue
    cnt = len(tagged)
    w   = math.sqrt(cnt)
    _edge(idx[cat], infinity_cat, w,
          f"{cat} → ∞ cat ({cnt})", _rgba("#1f77b4", LINK_ALPHA_OTH * 0.5))

# Flatten the accumulated links into the parallel arrays go.Sankey expects.
edge_items = list(edge_w.items())
src_i = [s for (s, _t), _w in edge_items]
tgt_i = [t for (_s, t), _w in edge_items]
val = [weight for _key, weight in edge_items]
htxt = [edge_h[key] + "<extra></extra>" for key, _w in edge_items]
lcol = [edge_c[key] for key, _w in edge_items]

# Hover text per node; paragraph nodes also show normalised divergence and shear.
idx2pid = {node: p for p, node in pid2idx.items()}
node_hover = []
for i, lbl in enumerate(labels):
    text = f"{lbl}"
    if i in idx2pid:
        pid = idx2pid[i]
        text += f"<br>div={div[pid]:+.3f}<br>shear={shr[pid]:.3f}"
    node_hover.append(text + "<extra></extra>")

# Assemble the Sankey trace from the node and link arrays built above.
sankey_trace = go.Sankey(
    node={
        "label": labels,
        "color": colors,
        "pad": 10,
        "thickness": 14,
        "hovertemplate": node_hover,
    },
    link={
        "source": src_i,
        "target": tgt_i,
        "value": val,
        "color": lcol,
        "line": {"color": lcol},
        "hovertemplate": htxt,
    },
)
fig = go.Figure(data=[sankey_trace])

fig.update_layout(
    title=None,
    showlegend=False,
    font={"size": LABEL_FONT_SIZE},
    height=800,
    width=1200,
)


# Output base name; a global ``sankey_name`` defined earlier overrides the default.
sankey_name = globals().get("sankey_name", "authorial_sankey")
fig.update_layout(margin={"t": 30, "l": 30, "b": 30, "r": 30})
fig.show()

# Standalone HTML (plotly.js from CDN) plus a JSON dump of the same graph.
html_path = f"{sankey_name}.html"
fig.write_html(html_path, include_plotlyjs="cdn")
print(f" Sankey diagram written in {html_path}")
json_path = f"{sankey_name}.json"
sankey_to_json(labels, src_i, tgt_i, val, colors, file=json_path)

# Keep the figure on the flow object for later access.
flow.authorial_dashboard = fig
