In [1]:
from cereprocess.datasets.pipeline import general_pipeline, neurotransformer_pipeline, resample
from cereprocess.datasets.channels import NEUROTRANSFORMER_CHANNELS
from models.neurogate import NeuroGate
from models.neurotransformer import Neurotransformer
import torch
import torch.nn.functional as F
import os
import mne
import numpy as np
import ollama
import gc
from pdr import PDREstimator
import warnings

# Suppress every Python warning for the remainder of the session.
warnings.filterwarnings("ignore")

# Electrode names grouped by scalp region; get_region_report aggregates the
# per-channel classifications over each of these groups.
CHANNEL_REGIONS = {
    "Frontal":        ["FP1", "FP2", "F3", "F4", "FZ"],
    "Left Temporal":  ["F7", "T3", "T5"],
    "Right Temporal": ["F8", "T4", "T6"],
    "Central":        ["C3", "C4", "CZ"],
    "Parietal":       ["P3", "P4", "PZ"],
    "Occipital":      ["O1", "O2"],
}

# Ollama model tags used for report generation.
# Each one must already be pulled locally (`ollama pull <tag>`).
MODELS = [
    "llama3.1",
    "qwen3:8b",
    "gemma3:12b",
    "mistral-nemo:12b",
    "deepseek-r1:8b",
    "glm4:9b",
    "koesn/llama3-openbiollm-8b",
    "meditron:7b",
]

# Input/output locations.
# - EDF_ROOT must contain the subfolders `normal` and `abnormal`.
# - Each .edf file must share its base name with the matching .txt report
#   under GROUND_TRUTH_DIR.
# - Generated reports are written to "<REPORTS>_<model>/<subfolder>/".
EDF_ROOT = "path/to/edf/files/"
GROUND_TRUTH_DIR = "reports_extracted_txt"
REPORTS = "reports_gen"

In [None]:
def ensure_dir(dir):
    """Create `dir` plus its `normal` and `abnormal` subfolders.

    Existing directories are left untouched, so the call is safe to repeat.
    """
    targets = [dir]
    targets.extend(os.path.join(dir, label) for label in ('normal', 'abnormal'))
    for target in targets:
        os.makedirs(target, exist_ok=True)

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [None]:
# Load models
# Build both networks with their default constructor arguments, then move
# each one to the selected device.
neurogate = NeuroGate()
neurogate = neurogate.to(device)

neurotransformer = Neurotransformer()
neurotransformer = neurotransformer.to(device)

In [None]:
# Load weights
# map_location=device remaps tensors that were saved from a CUDA process so
# the checkpoints also load on a CPU-only machine (the `device` fallback).
neurogate.load_state_dict(
    torch.load('model_weights/neurogate_wgts.pt', map_location=device)
)
neurogate.eval()  # inference mode: disables dropout / batch-norm updates

neurotransformer.load_state_dict(
    torch.load('model_weights/neurotransformer_wgts.pth', map_location=device)
)
neurotransformer.eval();  # trailing ';' suppresses the module repr in the cell output

In [None]:
# Define Pipelines
# One preprocessing pipeline per downstream consumer. Each is applied to the
# raw recording through `.apply(...)` inside get_ab_prob, get_region_report
# and get_pdr respectively.
# NOTE(review): 'NMT' presumably selects the NMT dataset configuration — confirm
# against cereprocess.datasets.pipeline.
neurogate_pl = general_pipeline('NMT')
neurotransformer_pl = neurotransformer_pipeline('NMT')
# NOTE(review): resample() is called with its defaults; confirm the resulting
# sampling rate matches the value handed to PDREstimator.
pdr_pl = resample()

In [None]:
# Define PDREstimator
# Positions of the two occipital electrodes within NEUROTRANSFORMER_CHANNELS.
# NOTE(review): the estimator is later fed data produced by `pdr_pl`; confirm
# that pipeline yields channels in NEUROTRANSFORMER_CHANNELS order, otherwise
# these indices point at the wrong electrodes.
o1_idx = NEUROTRANSFORMER_CHANNELS.index("O1")
o2_idx = NEUROTRANSFORMER_CHANNELS.index("O2")
# NOTE(review): the positional arguments are opaque from here — 200 is
# presumably the sampling rate in Hz and the trailing 0 an estimator option;
# confirm against the PDREstimator signature.
estimator = PDREstimator(200, o1_idx, o2_idx, 0)

In [None]:
def get_ab_prob(mne_data):
    """Return NeuroGate's abnormality probability for a recording, in percent.

    The recording is preprocessed with `neurogate_pl`, wrapped in a batch of
    size one and classified by `neurogate`. The value at index 1 of the
    softmax output is returned, scaled to 0-100.

    Parameters
    ----------
    mne_data : object accepted by `neurogate_pl.apply`
        The raw recording (the caller passes the result of
        `mne.io.read_raw_edf`).

    Returns
    -------
    numpy floating scalar
        Probability of the class at index 1, multiplied by 100.
        NOTE(review): index 1 is presumably the "abnormal" class — confirm
        against the NeuroGate training labels.
    """
    # Pass through neurogate
    processed_data = neurogate_pl.apply(mne_data)
    signal = processed_data.get_data()
    # Add a leading batch axis so the model sees a batch with a single item.
    batch = torch.from_numpy(signal[None, :, :]).float().to(device)

    with torch.no_grad():
        logits = neurogate(batch)

    # Normalise over the class axis explicitly. Calling softmax without `dim`
    # relies on a deprecated implicit choice and emits a warning.
    probs = F.softmax(logits, dim=-1).cpu().numpy().reshape(-1)
    return probs[1] * 100

In [None]:
def get_clinical_adjective(percentage, threshold):
    """
    Map a prevalence percentage to a clinical quantification term.

    The vocabulary is taken from the ACNS 2021 guidelines
    (Reference: Hirsch LJ, et al. J Clin Neurophysiol. 2021), but the upper
    edge of "Occasional" is shifted to ``threshold + 10`` instead of a fixed
    value. The resulting bands are:

        percentage < 1                      -> "Rare"
        1 <= percentage < threshold + 10    -> "Occasional"
        threshold + 10 <= percentage < 50   -> "Frequent"
        50 <= percentage < 90               -> "Abundant"
        percentage >= 90                    -> "Continuous"

    Bands are tested in that order, so when ``threshold + 10`` is 50 or more
    the "Frequent" band is empty.
    """
    # (exclusive upper bound, term) pairs, checked from the lowest band up.
    upper_bounds = (
        (1.0, "Rare"),
        (threshold + 10.0, "Occasional"),
        (50.0, "Frequent"),
        (90.0, "Abundant"),
    )
    for bound, term in upper_bounds:
        if percentage < bound:
            return term
    return "Continuous"



In [None]:
def get_region_report(mne_data, threshold):
    """Summarise Neurotransformer's per-window classifications by scalp region.

    Every channel listed in NEUROTRANSFORMER_CHANNELS is classified window by
    window as "normal wave", "spike and sharp wave" or "slow wave". For each
    region in CHANNEL_REGIONS the share of spike and slow windows (pooled over
    the region's channels) is computed. Any share at or above `threshold` is
    turned into a clinical phrase via `get_clinical_adjective`.

    Parameters
    ----------
    mne_data : object accepted by `neurotransformer_pl.apply`
        The raw recording.
    threshold : float
        Minimum percentage of windows for a finding to be reported.

    Returns
    -------
    dict
        ``{region: {"description": str,
                    "stats": {"spike_pct": float, "slow_pct": float}}}``.
        A region with no classified windows gets the description
        "No data available." and both percentages set to 0.0.

    Raises
    ------
    KeyError
        If a channel named in CHANNEL_REGIONS is missing from
        NEUROTRANSFORMER_CHANNELS.
    """
    all_events = np.array(["normal wave", "spike and sharp wave", "slow wave"])

    processed_data = neurotransformer_pl.apply(mne_data)
    # assumes get_data() yields (windows, channels, samples) — TODO confirm
    data = processed_data.get_data()

    # Classify each channel separately; the slice i:i+1 keeps the channel axis.
    result_events = {}
    for i, ch_name in enumerate(NEUROTRANSFORMER_CHANNELS):
        ch_data = torch.from_numpy(data[:, i:i+1, :]).float().to(device)
        with torch.no_grad():
            outputs = neurotransformer(ch_data)
        labels = outputs.cpu().numpy().argmax(axis=1)
        result_events[ch_name] = list(all_events[labels])

    region_report = {}

    for region_name, channels in CHANNEL_REGIONS.items():
        total_windows = 0
        spike_count = 0
        slow_count = 0

        for ch in channels:
            events = result_events[ch]
            total_windows += len(events)
            spike_count += events.count("spike and sharp wave")
            slow_count += events.count("slow wave")

        if total_windows == 0:
            # Nothing was classified (e.g. the recording produced no windows).
            # Avoid dividing by zero and do not claim the region is normal.
            spike_pct = 0.0
            slow_pct = 0.0
            description = "No data available."
        else:
            # Calculate Percentages
            spike_pct = (spike_count / total_windows) * 100
            slow_pct = (slow_count / total_windows) * 100

            # Generate Clinical Descriptors
            findings = []

            if spike_pct >= threshold:  # Threshold to report it
                adj = get_clinical_adjective(spike_pct, threshold)
                findings.append(f"{adj} epileptiform discharges")

            if slow_pct >= threshold:
                adj = get_clinical_adjective(slow_pct, threshold)
                findings.append(f"{adj} slowing")

            if not findings:
                description = "Normal activity."
            else:
                description = ", ".join(findings) + "."

        region_report[region_name] = {
            "description": description,
            "stats": {
                "spike_pct": spike_pct,
                "slow_pct": slow_pct
            }
        }
    return region_report

In [None]:
def get_pdr(mne_data):
    """Return the posterior dominant rhythm as report-ready text.

    The recording is passed through `pdr_pl` and handed to `estimator.fit`.
    The result is read through its 'pdr_o1' and 'pdr_o2' entries.

    Returns
    -------
    str
        "<mean> Hz" when both entries are truthy, "<value> Hz" when exactly
        one is, and "Not well-formed" when neither is. Values are formatted
        with one decimal place.
    """
    # Get PDR
    resampled = pdr_pl.apply(mne_data)
    pdr_res = estimator.fit(resampled.get_data())

    # Format PDR for the report: keep only the truthy estimates.
    usable = [pdr_res[key] for key in ('pdr_o1', 'pdr_o2') if pdr_res[key]]

    if not usable:
        return "Not well-formed"
    if len(usable) == 2:
        avg_pdr = (usable[0] + usable[1]) / 2
        return f"{avg_pdr:.1f} Hz"
    return f"{usable[0]:.1f} Hz"

In [None]:
# Run inference on each file for each model
sub_folders = ['normal', 'abnormal']

# Minimum share of windows (in percent) a regional finding needs to be reported.
REPORT_THRESHOLD_PCT = 25

# The system prompt is the same for every file and model, so build it once.
system_instruction = """
            You are a clinical neurologist. Write a standard EEG report based strictly on the provided statistics.
            
            Format Requirements:
            1. Use exactly two sections: "FACTUAL REPORT" and "IMPRESSION". It should not have any other sections.
            2. In FACTUAL REPORT, describe the background (PDR) first, then regional findings.
            3. In IMPRESSION, state "Normal EEG" or "Abnormal EEG" followed by a summary sentence.
            4. Absolutely do not mention percentages in the final text; use clinical terms (Frequent, Occasional).
            5. Write each section as a paragraph.
            6. Do not use points.
            7. Do not describe each and every region separately.
            8. Do not assume any information on your own.

            EXAMPLES:
            Normal Report:
            FACTUAL REPORT: Background rhythm shows alpha waveform seen around 8 Hz which is appropriate for age. Photic stimulation and HV not performed due to the state of patient. Intermittent EMG artifacts were seen. Stage II sleep was not achieved.

            IMPRESSION: This EEG showed very mild encephalopathy and there is no element of non convulsive status. Kindly correlate with clinical picture.

            Abnormal Report:
            FACTUAL REPORT: Background rhythm during awake stage shows well-organized, welldeveloped, average voltage 10 hertz alpha activity in the posterior regions which is appropriate for age. It blocks with eye opening and it is bilaterally synchronous and symmetrical. Beta activity in the frontal or central areas is seen with average voltage and amplitude. Photic stimulation and hyperventilation was performed. Intermittent EMG artifacts were seen. Stage II sleep was not achieved.

            IMPRESSION: This is an abnormal EEG with asymmetrical features there is delta wave slowing from right hemisphere. There are some faster frequencies on left side also that could be due to breach rhythm .kindly correlate clinically
            """

# The EEG-derived statistics do not depend on the LLM. Cache them per file so
# the NeuroGate / Neurotransformer / PDR passes run once instead of once per
# model. Entries are small: a float, a dict of region summaries and a string.
eeg_stats_cache = {}

for model in MODELS:
    # `report_dir` rather than `dir`, to avoid shadowing the builtin.
    report_dir = REPORTS + f"_{model}"
    ensure_dir(report_dir)

    for folder in sub_folders:
        folder_path = os.path.join(EDF_ROOT, folder)

        for filename in sorted(os.listdir(folder_path)):
            # Only EDF recordings are inputs; skip anything else in the folder.
            stem, ext = os.path.splitext(filename)
            if ext.lower() != '.edf':
                continue

            file_path = os.path.join(folder_path, filename)
            out_file = os.path.join(report_dir, folder, stem + '.txt')

            # Resume support: never regenerate a report that is already on disk.
            if os.path.exists(out_file):
                print(f"Report already exists for {filename}")
                continue

            if file_path not in eeg_stats_cache:
                try:
                    mne_data = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
                except Exception as e:
                    print(f"Could not read {file_path}: {e}")
                    continue

                # Same call order as before: abnormality, regions, then PDR.
                eeg_stats_cache[file_path] = (
                    get_ab_prob(mne_data),
                    get_region_report(mne_data, REPORT_THRESHOLD_PCT),
                    get_pdr(mne_data),
                )
                del mne_data  # release the preloaded signal as early as possible

            ab_prob, region_report, pdr_text = eeg_stats_cache[file_path]

            # Construct the Prompt
            prompt_content = f"""
            PATIENT STATISTICS:
            - Global Abnormality Probability: {ab_prob:.1f}% (If >50%, consider Abnormal)
            - Posterior Dominant Rhythm: {pdr_text}
            
            REGIONAL ANALYSIS:
            """
            for region, desc in region_report.items():
                prompt_content += f"- {region}: {desc}\n"

            try:
                response = ollama.chat(model=model, messages=[
                    {'role': 'system', 'content': system_instruction},
                    {'role': 'user', 'content': prompt_content},
                ])

                report_text = response['message']['content']

                # Save to text file. The evaluation cell reads these files as
                # UTF-8, so write them as UTF-8 instead of the locale default.
                with open(out_file, "w", encoding="utf-8") as f:
                    f.write(report_text)

                print(f"Saved report to: {out_file}")

            except Exception as e:
                print(f"Ollama Error: {e}")

            if device == 'cuda':
                torch.cuda.empty_cache()
            gc.collect()


In [2]:
import os
import numpy as np
import nltk
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer
from bert_score import BERTScorer
from tqdm.notebook import tqdm
import gc
import torch

# Ensure NLTK resources are available
# - wordnet / omw-1.4: used by METEOR
# - punkt / punkt_tab: used by nltk.word_tokenize. Recent NLTK releases look
#   up 'punkt_tab' instead of the pickled 'punkt' models, so request both.
#   With quiet=True an unavailable package is reported through the return
#   value rather than an exception.
for _resource in ('wordnet', 'punkt', 'punkt_tab', 'omw-1.4'):
    nltk.download(_resource, quiet=True)

# Data loading function
def load_data_for_all_models(gt_dir, models=None, gen_prefix="reports_gen"):
    """
    Pre-loads all text data so we don't have to keep reading files.

    For every model, walks "<gen_prefix>_<model>" and pairs each generated
    ``.txt`` report with the ground-truth file at the same relative path under
    `gt_dir`. Text is read as UTF-8 and stripped of surrounding whitespace.

    Parameters
    ----------
    gt_dir : str
        Root directory of the ground-truth reports.
    models : iterable of str, optional
        Model keys to load. Defaults to the module-level ``MODELS`` list.
    gen_prefix : str, optional
        Prefix of the per-model output directories. Defaults to
        "reports_gen".

    Returns
    -------
    dict
        ``{ model_key: {'refs': [...], 'cands': [...], 'name': model_key} }``.
        Models whose directory does not exist are omitted. ``refs[i]`` and
        ``cands[i]`` always belong to the same file.
    """
    if models is None:
        models = MODELS

    all_data = {}

    print("Loading text data for all models...")
    for model_key in models:
        gen_dir = f"{gen_prefix}_{model_key}"

        if not os.path.exists(gen_dir):
            print(f"Skipping {model_key} (Directory not found: {gen_dir})")
            continue

        refs = []
        cands = []
        missing_gt = 0

        for root, subdirs, files in os.walk(gen_dir):
            # Sort in place so traversal order (and so the pair order) is
            # the same on every run and platform.
            subdirs.sort()
            for file in sorted(files):
                if not file.endswith(".txt"):
                    continue

                gen_path = os.path.join(root, file)
                rel_path = os.path.relpath(gen_path, gen_dir)
                gt_path = os.path.join(gt_dir, rel_path)

                if not os.path.exists(gt_path):
                    # No reference to score against; count it so the gap is
                    # visible instead of silently shrinking the sample.
                    missing_gt += 1
                    continue

                with open(gen_path, 'r', encoding='utf-8') as f:
                    cands.append(f.read().strip())
                with open(gt_path, 'r', encoding='utf-8') as f:
                    refs.append(f.read().strip())

        all_data[model_key] = {'refs': refs, 'cands': cands, 'name': model_key}
        print(f"  -> {model_key}: Loaded {len(refs)} pairs.")
        if missing_gt:
            print(f"     ({missing_gt} generated report(s) skipped: no ground-truth file)")

    return all_data

# Metric Calculations
def get_cpu_metrics(refs, cands):
    """Score candidate texts against references with BLEU-4, ROUGE-L and METEOR.

    None of the three metrics needs a GPU. `refs[i]` is the single reference
    for `cands[i]`.

    Returns
    -------
    tuple
        ``(bleu, rouge, meteor)``: corpus-level BLEU-4 with uniform n-gram
        weights, mean ROUGE-L F-measure (stemmed) and mean METEOR.
    """
    # BLEU and METEOR work on tokens. Each candidate has exactly one
    # reference, wrapped in a list because both APIs take a list of references.
    tokenized_refs = [[nltk.word_tokenize(ref_text)] for ref_text in refs]
    tokenized_cands = [nltk.word_tokenize(cand_text) for cand_text in cands]

    # BLEU-4
    uniform_weights = (0.25, 0.25, 0.25, 0.25)
    bleu = corpus_bleu(tokenized_refs, tokenized_cands, weights=uniform_weights)

    # METEOR
    per_doc_meteor = []
    for ref_list, cand_toks in zip(tokenized_refs, tokenized_cands):
        per_doc_meteor.append(meteor_score(ref_list, cand_toks))
    meteor = np.mean(per_doc_meteor)

    # ROUGE-L (operates on the raw strings)
    rouge_l = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    per_doc_rouge = []
    for ref_text, cand_text in zip(refs, cands):
        per_doc_rouge.append(rouge_l.score(ref_text, cand_text)['rougeL'].fmeasure)
    rouge = np.mean(per_doc_rouge)

    return bleu, rouge, meteor


# Load all data
model_data = load_data_for_all_models(GROUND_TRUTH_DIR)

# One (initially empty) metrics dict per model that was actually loaded.
results = {}
for key in model_data:
    results[key] = {}

# 2. Compute CPU Metrics (BLEU, ROUGE, METEOR)
print("\n--- Computing Standard Metrics (BLEU, ROUGE, METEOR) ---")
for key, data in model_data.items():
    # Models with no matched report pairs have nothing to score.
    if data['refs']:
        bleu, rouge, meteor = get_cpu_metrics(data['refs'], data['cands'])
        results[key].update(BLEU=bleu, ROUGE=rouge, METEOR=meteor)
        print(f"{data['name']}: BLEU={bleu:.2f}, ROUGE={rouge:.2f}, METEOR={meteor:.2f}")

# Compute BERTScore
# We load one BERT model, process ALL LLMs, then unload it.
# Each entry is (model identifier, column name in `results`, batch size).
bert_models = [
    ("distilbert-base-uncased", "DistilBERT", 64),
    ("roberta-large", "RoBERTa", 32),
    ("microsoft/deberta-xlarge-mnli", "DeBERTa", 2)
]

# Safety limit: Truncate text to ~3000 chars
MAX_CHAR_LENGTH = 3000

# Use the GPU when available, otherwise fall back to the CPU instead of
# failing on a hard-coded 'cuda' device.
bert_device = 'cuda' if torch.cuda.is_available() else 'cpu'

for model_path, short_name, batch_size in bert_models:
    print(f"\n--- Loading {short_name} (Batch: {batch_size}) ... ---")

    # Load model
    scorer = BERTScorer(model_type=model_path, lang="en", rescale_with_baseline=False, device=bert_device)

    for key, data in model_data.items():
        if not data['refs']: continue

        # Truncate long strings
        # Replace empty strings
        clean_cands = []
        for c in data['cands']:
            c_trunc = c[:MAX_CHAR_LENGTH]
            # Placeholder for empty files
            if not c_trunc.strip():
                c_trunc = "."
            clean_cands.append(c_trunc)

        clean_refs = [r[:MAX_CHAR_LENGTH] for r in data['refs']]

        print(f"   -> Processing {data['name']}...", end=" ", flush=True)

        # Pass the cleaned lists
        P, R, F1 = scorer.score(clean_cands, clean_refs, batch_size=batch_size, verbose=False)

        # Only the mean F1 is kept per (LLM, BERT model) pair.
        mean_f1 = F1.mean().item()
        results[key][short_name] = mean_f1
        print(f"F1: {mean_f1:.4f}")

        del P, R, F1, clean_cands, clean_refs
        if bert_device == 'cuda':
            torch.cuda.empty_cache()

    print(f"   -> Unloading {short_name}...")
    del scorer
    gc.collect()
    if bert_device == 'cuda':
        torch.cuda.empty_cache()
        
# Header formatting
# (column label shown in the table, key used in each model's results dict)
metric_columns = [
    ('BLEU', 'BLEU'),
    ('ROUGE', 'ROUGE'),
    ('METEOR', 'METEOR'),
    ('Distil', 'DistilBERT'),
    ('RoBERTa', 'RoBERTa'),
    ('DeBERTa', 'DeBERTa'),
]

# Model name is left-aligned in 30 characters, every metric in 10.
header = " ".join([f"{'Model Name':<30}"] + [f"{label:<10}" for label, _ in metric_columns])
rule_width = len(header)
print("\n" + "=" * rule_width)
print(header)
print("-" * rule_width)

for key in MODELS:
    r = results.get(key)
    if not r:
        # Handle case where model wasn't found/processed
        print(f"{key:<30} -- Data not found --")
        continue

    name = model_data[key]['name']

    # Formatted row with fixed width spacing; missing metrics print as 0.
    cells = [f"{name:<30}"]
    cells.extend(f"{r.get(metric_key, 0):<10.4f}" for _, metric_key in metric_columns)
    print(" ".join(cells))

print("=" * rule_width + "\n")

Loading text data for all models...
  -> llama3.1: Loaded 676 pairs.
  -> qwen3:8b: Loaded 676 pairs.
  -> gemma3:12b: Loaded 676 pairs.
  -> mistral-nemo:12b: Loaded 676 pairs.
  -> deepseek-r1:8b: Loaded 676 pairs.
  -> glm4:9b: Loaded 676 pairs.
  -> koesn/llama3-openbiollm-8b: Loaded 676 pairs.
  -> meditron:7b: Loaded 676 pairs.

--- Computing Standard Metrics (BLEU, ROUGE, METEOR) ---
llama3.1: BLEU=0.09, ROUGE=0.23, METEOR=0.26
qwen3:8b: BLEU=0.10, ROUGE=0.25, METEOR=0.27
gemma3:12b: BLEU=0.02, ROUGE=0.18, METEOR=0.19
mistral-nemo:12b: BLEU=0.09, ROUGE=0.25, METEOR=0.30
deepseek-r1:8b: BLEU=0.06, ROUGE=0.21, METEOR=0.23
glm4:9b: BLEU=0.03, ROUGE=0.17, METEOR=0.27
koesn/llama3-openbiollm-8b: BLEU=0.02, ROUGE=0.10, METEOR=0.10
meditron:7b: BLEU=0.00, ROUGE=0.07, METEOR=0.10

--- Loading DistilBERT (Batch: 64) ... ---
   -> Processing llama3.1... F1: 0.8099
   -> Processing qwen3:8b... F1: 0.8172
   -> Processing gemma3:12b... F1: 0.7870
   -> Processing mistral-nemo:12b... F1: 0.8

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


   -> Processing llama3.1... F1: 0.8667
   -> Processing qwen3:8b... F1: 0.8701
   -> Processing gemma3:12b... F1: 0.8499
   -> Processing mistral-nemo:12b... F1: 0.8641
   -> Processing deepseek-r1:8b... F1: 0.8653
   -> Processing glm4:9b... F1: 0.8503
   -> Processing koesn/llama3-openbiollm-8b... F1: 0.8270
   -> Processing meditron:7b... F1: 0.7606
   -> Unloading RoBERTa...

--- Loading DeBERTa (Batch: 2) ... ---
   -> Processing llama3.1... F1: 0.6263
   -> Processing qwen3:8b... F1: 0.6465
   -> Processing gemma3:12b... F1: 0.5964
   -> Processing mistral-nemo:12b... F1: 0.6386
   -> Processing deepseek-r1:8b... F1: 0.6228
   -> Processing glm4:9b... F1: 0.5917
   -> Processing koesn/llama3-openbiollm-8b... F1: 0.5109
   -> Processing meditron:7b... F1: 0.4297
   -> Unloading DeBERTa...

Model Name                     BLEU       ROUGE      METEOR     Distil     RoBERTa    DeBERTa   
------------------------------------------------------------------------------------------------