In [None]:
from cereprocess.datasets.pipeline import general_pipeline, neurotransformer_pipeline, resample
from cereprocess.datasets.channels import NEUROTRANSFORMER_CHANNELS
from models.neurogate import NeuroGate
from models.neurotransformer import Neurotransformer
import torch
import torch.nn.functional as F
import os
import mne
import numpy as np
import ollama
import gc
from pdr import PDREstimator

# Scalp regions used to pool per-channel predictions. The channel labels are the
# lookup keys into the per-channel results built in get_region_report, so each
# label has to be present in NEUROTRANSFORMER_CHANNELS.
CHANNEL_REGIONS = {
    region: labels.split()
    for region, labels in (
        ("Frontal", "FP1 FP2 F3 F4 FZ"),
        ("Left Temporal", "F7 T3 T5"),
        ("Right Temporal", "F8 T4 T6"),
        ("Central", "C3 C4 CZ"),
        ("Parietal", "P3 P4 PZ"),
        ("Occipital", "O1 O2"),
    )
}

In [None]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
# Kept as a plain string because the batch loop compares it with 'cuda'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Load models
# Both networks are built with their default constructor arguments and moved
# to `device`; the trained parameters are not loaded by these two lines.
neurogate = NeuroGate().to(device)
neurotransformer = Neurotransformer().to(device)

In [None]:
# Load weights
# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved from a GPU session still loads when this notebook falls back to the CPU.
neurogate.load_state_dict(
    torch.load('model_weights/neurogate_wgts.pt', map_location=device)
)
# eval() switches the module to evaluation mode for inference.
neurogate.eval()
neurotransformer.load_state_dict(
    torch.load('model_weights/neurotransformer_wgts.pth', map_location=device)
)
# The trailing semicolon keeps the notebook from echoing the module repr.
neurotransformer.eval();

In [None]:
# Define Pipelines
# One preprocessing pipeline per consumer: neurogate_pl is applied in
# get_ab_prob, neurotransformer_pl in get_region_report and pdr_pl in get_pdr.
# NOTE(review): 'NMT' presumably selects a dataset-specific configuration —
# confirm against cereprocess.datasets.pipeline.
neurogate_pl = general_pipeline('NMT')
neurotransformer_pl = neurotransformer_pipeline('NMT')
pdr_pl = resample()

In [None]:
# Define PDREstimator
# The O1/O2 positions are looked up in NEUROTRANSFORMER_CHANNELS.
# NOTE(review): the estimator is fitted on the output of pdr_pl (see get_pdr),
# so these indices are only correct if that pipeline yields channels in
# NEUROTRANSFORMER_CHANNELS order — confirm.
o1_idx = NEUROTRANSFORMER_CHANNELS.index("O1")
o2_idx = NEUROTRANSFORMER_CHANNELS.index("O2")
# NOTE(review): 200 and 0 are positional arguments whose meaning is not visible
# here (200 is presumably a sampling rate in Hz) — confirm against
# pdr.PDREstimator.
estimator = PDREstimator(200, o1_idx, o2_idx, 0)

In [None]:
def get_clinical_adjective(percentage):
    """
    Map a prevalence percentage to a clinical frequency term.

    The terms follow the ACNS 2021 terminology
    (Hirsch LJ, et al. J Clin Neurophysiol. 2021), with the boundary between
    "Occasional" and "Frequent" deliberately shifted up to 15% here.

    Args:
        percentage: Prevalence of a finding, in percent.

    Returns:
        "Occasional" below 15, "Frequent" below 50, "Abundant" below 90,
        otherwise "Continuous".
    """
    # (exclusive upper bound, term), checked in ascending order.
    bands = (
        (15.0, "Occasional"),
        (50.0, "Frequent"),
        (90.0, "Abundant"),
    )
    for upper_bound, term in bands:
        if percentage < upper_bound:
            return term
    return "Continuous"

In [None]:
def get_ab_prob(mne_data):
    """
    Estimate the recording-level abnormality probability with NeuroGate.

    Args:
        mne_data: Recording accepted by ``neurogate_pl.apply`` (the batch loop
            passes the object returned by ``mne.io.read_raw_edf``).

    Returns:
        Softmax probability of class index 1, scaled to a percentage. The
        caller reports this value as the abnormality probability.
    """
    # Pass through neurogate
    processed_data = neurogate_pl.apply(mne_data)
    signal = processed_data.get_data()
    # Add a leading axis so the model receives a batch holding one recording.
    batch = torch.from_numpy(signal[None, :, :]).float().to(device)

    with torch.no_grad():
        outputs = neurogate(batch)

    # Normalise over the class axis explicitly: calling softmax without `dim`
    # is deprecated and makes PyTorch pick the axis itself (with a warning).
    probabilities = F.softmax(outputs, dim=-1).cpu().numpy().reshape(-1)
    return probabilities[1] * 100

In [None]:
def get_region_report(mne_data, report_threshold_pct=5.0):
    """
    Summarise per-window Neurotransformer predictions for each scalp region.

    Every channel in NEUROTRANSFORMER_CHANNELS is classified window by window
    into one of three classes (index 0 "normal wave", 1 "spike and sharp
    wave", 2 "slow wave"). For each region in CHANNEL_REGIONS the windows of
    its channels are pooled, and the share of spike and slow windows is turned
    into a short text description.

    Args:
        mne_data: Recording accepted by ``neurotransformer_pl.apply``.
        report_threshold_pct: Minimum percentage of pooled windows a finding
            needs before it is mentioned in the description.

    Returns:
        dict mapping each region name to
        ``{"description": str, "stats": {"spike_pct": float, "slow_pct": float}}``.
        A region without any classified window gets the description
        "No analyzable data." and both percentages set to 0.0.

    Raises:
        KeyError: If CHANNEL_REGIONS lists a channel that is not in
            NEUROTRANSFORMER_CHANNELS.
    """
    # Class indices produced by the argmax below.
    SPIKE_CLASS, SLOW_CLASS = 1, 2

    processed_data = neurotransformer_pl.apply(mne_data)
    # Indexed below with three axes; axis 1 is addressed by channel position.
    data = processed_data.get_data()

    # Predicted class index per window, keyed by channel label.
    predictions = {}
    for i, ch_name in enumerate(NEUROTRANSFORMER_CHANNELS):
        # Slice with i:i+1 so the channel axis is kept (size 1) for the model.
        ch_data = torch.from_numpy(data[:, i:i+1, :]).float().to(device)
        with torch.no_grad():
            outputs = neurotransformer(ch_data)
        predictions[ch_name] = outputs.cpu().numpy().argmax(axis=1)

    region_report = {}

    for region_name, channels in CHANNEL_REGIONS.items():
        total_windows = 0
        spike_count = 0
        slow_count = 0

        for ch in channels:
            predicted = predictions[ch]
            total_windows += len(predicted)
            spike_count += int(np.count_nonzero(predicted == SPIKE_CLASS))
            slow_count += int(np.count_nonzero(predicted == SLOW_CLASS))

        if total_windows == 0:
            # Nothing was classified for this region (e.g. the pipeline
            # produced no windows); avoid dividing by zero and say so
            # instead of labelling the region as normal.
            region_report[region_name] = {
                "description": "No analyzable data.",
                "stats": {
                    "spike_pct": 0.0,
                    "slow_pct": 0.0
                }
            }
            continue

        # Calculate Percentages
        spike_pct = (spike_count / total_windows) * 100
        slow_pct = (slow_count / total_windows) * 100

        # Generate Clinical Descriptors
        findings = []

        # 1. Analyze Epileptiform Discharges
        if spike_pct >= report_threshold_pct:
            adj = get_clinical_adjective(spike_pct)
            findings.append(f"{adj} epileptiform discharges")

        # 2. Analyze Slowing
        if slow_pct >= report_threshold_pct:
            adj = get_clinical_adjective(slow_pct)
            findings.append(f"{adj} slowing")

        if findings:
            description = ", ".join(findings) + "."
        else:
            description = "Normal activity."

        region_report[region_name] = {
            "description": description,
            "stats": {
                "spike_pct": spike_pct,
                "slow_pct": slow_pct
            }
        }
    return region_report

In [None]:
def get_pdr(mne_data):
    """
    Produce the posterior dominant rhythm text used in the report prompt.

    The recording is passed through ``pdr_pl`` and the resulting array is
    handed to ``estimator.fit``. The 'pdr_o1' and 'pdr_o2' entries of the
    result are treated as missing when they are falsy.

    Args:
        mne_data: Recording accepted by ``pdr_pl.apply``.

    Returns:
        The mean of the usable readings formatted as e.g. "9.5 Hz", or
        "Not well-formed" when neither reading is usable.
    """
    # Get PDR
    signal = pdr_pl.apply(mne_data).get_data()
    pdr_res = estimator.fit(signal)

    # Format PDR for the report: keep only the usable occipital readings.
    readings = [
        value
        for value in (pdr_res['pdr_o1'], pdr_res['pdr_o2'])
        if value
    ]
    if not readings:
        return "Not well-formed"

    # With one reading this is the reading itself, with two it is their mean.
    return f"{sum(readings) / len(readings):.1f} Hz"

In [None]:
# Run inference on each file
# For every EDF recording: compute the abnormality probability, the regional
# findings and the PDR, ask the LLM for a report and save it as a text file.
# The dataset location can be overridden with the EDF_ROOT environment
# variable; the default is the original machine-specific path.
EDF_ROOT = os.environ.get(
    "EDF_ROOT",
    "/media/tukl/ee279b7d-bb8a-4a20-8bf9-90b2c542efcc/EEG Datasets/nmt_4k_f/edf",
)
REPORTS = "reports_gen"
sub_folders = ['normal', 'abnormal']

for folder in sub_folders:
    folder_path = os.path.join(EDF_ROOT, folder)
    report_dir = os.path.join(REPORTS, folder)
    # Create the output folder up front; listing a missing folder used to
    # abort the run with FileNotFoundError on a fresh checkout.
    os.makedirs(report_dir, exist_ok=True)

    # Sorted so that reruns visit the recordings in a stable order.
    for filename in sorted(os.listdir(folder_path)):
        # splitext also copes with upper-case extensions, which
        # str.replace('.edf', '.txt') left untouched (the report would then
        # have been written to a path ending in .EDF).
        stem, extension = os.path.splitext(filename)
        if extension.lower() != '.edf':
            continue

        file_path = os.path.join(folder_path, filename)
        out_file = os.path.join(report_dir, stem + '.txt')

        # Makes the batch resumable: recordings with a saved report are skipped.
        if os.path.exists(out_file):
            print(f"Report already exists for {filename}")
            continue

        try:
            mne_data = mne.io.read_raw_edf(file_path, preload=True)
        except Exception as e:
            print(e)
            continue

        # NOTE(review): the same mne_data object is handed to three different
        # pipelines below. If any pipeline modifies its input in place, the
        # later ones see altered data — confirm against cereprocess.
        ab_prob = get_ab_prob(mne_data)
        print(ab_prob)

        region_report = get_region_report(mne_data)
        for region in CHANNEL_REGIONS.keys():
            print(region_report[region])

        pdr_text = get_pdr(mne_data)
        print(pdr_text)

        # Construct the Prompt
        prompt_content = f"""
        PATIENT STATISTICS:
        - Global Abnormality Probability: {ab_prob:.1f}% (If >50%, consider Abnormal)
        - Posterior Dominant Rhythm: {pdr_text}
        
        REGIONAL ANALYSIS:
        """
        # Each entry is the full per-region dict (description plus stats).
        for region, desc in region_report.items():
            prompt_content += f"- {region}: {desc}\n"

        system_instruction = """
        You are a clinical neurologist. Write a standard EEG report based strictly on the provided statistics.
        
        Format Requirements:
        1. Use exactly two sections: "FACTUAL REPORT" and "IMPRESSION".
        2. In FACTUAL REPORT, describe the background (PDR) first, then regional findings.
        3. In IMPRESSION, state "Normal EEG" or "Abnormal EEG" followed by a summary sentence.
        4. Absolutely do not mention percentages in the final text; use clinical terms (Frequent, Occasional).
        5. Write each section as a paragraph.
        6. Do not use points.
        7. Do not describe each and every region separately.
        8. Do not assume any information on your own.

        EXAMPLES:
        Normal Report:
        FACTUAL REPORT: Background rhythm shows alpha waveform seen around 8 Hz which is appropriate for age. Photic stimulation and HV not performed due to the state of patient. Intermittent EMG artifacts were seen. Stage II sleep was not achieved.

        IMPRESSION: This EEG showed very mild encephalopathy and there is no element of non convulsive status. Kindly correlate with clinical picture.

        Abnormal Report:
        FACTUAL REPORT: Background rhythm during awake stage shows well-organized, welldeveloped, average voltage 10 hertz alpha activity in the posterior regions which is appropriate for age. It blocks with eye opening and it is bilaterally synchronous and symmetrical. Beta activity in the frontal or central areas is seen with average voltage and amplitude. Photic stimulation and hyperventilation was performed. Intermittent EMG artifacts were seen. Stage II sleep was not achieved.

        IMPRESSION: This is an abnormal EEG with asymmetrical features there is delta wave slowing from right hemisphere. There are some faster frequencies on left side also that could be due to breach rhythm .kindly correlate clinically
        """

        try:
            response = ollama.chat(model='llama3.1', messages=[
                {'role': 'system', 'content': system_instruction},
                {'role': 'user', 'content': prompt_content},
            ])

            report_text = response['message']['content']

            # Save to text file. UTF-8 is set explicitly so the file can be
            # read back with encoding='utf-8' regardless of the platform default.
            with open(out_file, "w", encoding="utf-8") as f:
                f.write(report_text)

            print(f"Saved report to: {out_file}")

        except Exception as e:
            print(f"Ollama Error: {e}")

        # Release cached GPU memory and collect garbage between recordings.
        if device == 'cuda':
            torch.cuda.empty_cache()
        gc.collect()


In [None]:
import os
import numpy as np
from rouge_score import rouge_scorer
from bert_score import score
from tqdm import tqdm

# --- CONFIGURATION ---
# Folder with real doctor reports
GROUND_TRUTH_DIR = "reports_extracted_txt"
# Folder with AI generated reports; this is the same folder name the
# generation cell writes to (REPORTS).
GENERATED_DIR = "reports_gen"

def load_file_pairs(gt_dir, gen_dir):
    """
    Collect (ground truth, generated) report texts that share a relative path.

    The generated directory is walked recursively. A ``.txt`` file found there
    is used only when a file with the same relative path exists under
    ``gt_dir``; generated files without a counterpart are ignored silently.

    Args:
        gt_dir: Root folder of the ground-truth reports.
        gen_dir: Root folder of the generated reports.

    Returns:
        Tuple ``(references, candidates)``: two equally long lists of stripped
        file contents, aligned index by index.
    """
    print("Scanning directories...")

    references, candidates = [], []

    # Driven by the generated folder, since only generated reports are scored.
    for current_dir, _subdirs, names in os.walk(gen_dir):
        for name in names:
            if not name.endswith(".txt"):
                continue

            generated_path = os.path.join(current_dir, name)
            # Mirror the relative location (e.g. normal/file.txt) under gt_dir.
            truth_path = os.path.join(
                gt_dir, os.path.relpath(generated_path, gen_dir)
            )
            if not os.path.exists(truth_path):
                continue

            with open(generated_path, 'r', encoding='utf-8') as handle:
                candidates.append(handle.read().strip())

            with open(truth_path, 'r', encoding='utf-8') as handle:
                references.append(handle.read().strip())

    print(f"Found {len(references)} matching report pairs.")
    return references, candidates

def compute_metrics(references, candidates):
    """
    Print corpus-level ROUGE-L and BERTScore for paired reports.

    Args:
        references: Ground-truth report texts.
        candidates: Generated report texts, aligned index by index with
            ``references``.

    Returns:
        None. The mean precision, recall and F1 of both metrics are printed;
        when ``references`` is empty only a notice is printed.
    """
    if not references:
        print("No data to evaluate.")
        return

    print("\n--- Computing ROUGE-L ---")
    # Only the 'rougeL' variant is computed, with stemming enabled.
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    # One ROUGE-L score object (precision / recall / fmeasure) per report pair.
    pair_scores = []
    for ref, cand in tqdm(zip(references, candidates), total=len(references)):
        pair_scores.append(scorer.score(ref, cand)['rougeL'])

    print("\n--- Computing BERTScore (This may take a moment) ---")
    # BERTScore compares the texts through a pre-trained language model.
    # 'distilbert-base-uncased' is the faster choice; 'roberta-large' can be
    # used for higher accuracy when enough GPU memory is available.
    bert_precision, bert_recall, bert_f1 = score(
        candidates,
        references,
        lang="en",
        verbose=True,
        model_type="distilbert-base-uncased",
    )

    # --- PRINT RESULTS ---
    banner = "=" * 40
    print("\n" + banner)
    print("FINAL EVALUATION RESULTS")
    print(banner)

    print(f"ROUGE-L F1:       {np.mean([s.fmeasure for s in pair_scores]):.4f}")
    print(f"ROUGE-L Precision:{np.mean([s.precision for s in pair_scores]):.4f}")
    print(f"ROUGE-L Recall:   {np.mean([s.recall for s in pair_scores]):.4f}")
    print("-" * 40)
    print(f"BERTScore F1:     {bert_f1.mean().item():.4f}")
    print(f"BERTScore Precision: {bert_precision.mean().item():.4f}")
    print(f"BERTScore Recall:    {bert_recall.mean().item():.4f}")
    print(banner)

# Entry point of the evaluation: pair up the reports found on disk, then score
# them. The guard is true when this code runs as the main module (a notebook
# cell or a script) and false when it is imported.
if __name__ == "__main__":
    # 1. Load Data
    # Both lists are aligned index by index and stay available as globals.
    ground_truth_texts, generated_texts = load_file_pairs(GROUND_TRUTH_DIR, GENERATED_DIR)
    
    # 2. Run Evaluation
    # Prints the metrics; nothing is returned.
    compute_metrics(ground_truth_texts, generated_texts)