In [11]:
# ghw280_analyzer.py
# GHW280 EEG Channel Analyzer (Class-Based)
# Preserves your original code, aesthetics, and uses your expert electrode mapping

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import numpy as np
import os


class GHW280Analyzer:
    """
    EEG Channel Analyzer for the 280-channel Geodesic Head Web (GHW280)
    - Loads electrode coordinates
    - Assigns anatomical regions and functional roles
    - Visualizes 3D and top-down layouts
    - Uses YOUR electrode definitions for blinks and facial muscles
    """

    def __init__(self, filepath):
        self.filepath = filepath
        self.channels = None        # List of (name, x, y, z)
        self.ch_dict = None         # Dict: name -> np.array([x,y,z])
        self.mapping = None         # Dict: original -> renamed (e.g., '274' -> 'hEOG_L')
        self.regions = None         # Dict: region -> [channel names]
        self.fiducials = {}

        # Your expert functional definitions
        self.blinks = ['274', '41']           # Left and right blink detectors
        self.facial_muscles = ['280', '52']   # Left and right facial EMG
        self.ground = '32'                    # Ground and vertical EOG

    def parse_gpsc(self):
        """Parse GHW280 .txt/.gpsc file: Label X Y Z"""
        if not os.path.exists(self.filepath):
            raise FileNotFoundError(f"File not found: {self.filepath}")

        channels = []
        with open(self.filepath, 'r') as file:
            for line in file:
                parts = line.strip().split()
                if len(parts) < 4:
                    continue
                name = parts[0]
                try:
                    x, y, z = map(float, parts[1:4])
                    channels.append((name, x, y, z))
                except ValueError:
                    print(f"Warning: Could not parse {name}")
                    continue
        print(f"Parsed {len(channels)} channels from {self.filepath}")
        self.channels = channels
        self.ch_dict = {ch[0]: np.array(ch[1:]) for ch in channels}
        return self

    def assign_regions_and_rename(self):
        """Assign anatomical regions and rename channels using your expert labels"""
        if self.channels is None:
            raise RuntimeError("Call parse_gpsc() first.")

        # Extract fiducials
        for fid_name in ['FidNz', 'FidT9', 'FidT10', 'Cz']:
            if fid_name in self.ch_dict:
                self.fiducials[fid_name] = self.ch_dict[fid_name]

        if len(self.fiducials) < 3:
            print("Warning: Not enough fiducials found, using approximate values")
            fnz = np.array([100, 200, 150])
            cz = np.array([100, 100, 250])
        else:
            fnz = self.fiducials.get('FidNz', np.array([100, 200, 150]))
            cz = self.fiducials.get('Cz', np.array([100, 100, 250]))

        # Compute head center
        head_center = np.mean([pos for pos in self.ch_dict.values()], axis=0)

        # Initialize
        self.regions = {
            'Frontal': [], 'Central': [], 'Parietal': [], 'Occipital': [],
            'Temporal': [], 'Left': [], 'Right': [], 'Midline': [],
            'Mastoid': [], 'Periocular': [], 'Vertex': [], 'Other': []
        }
        self.mapping = {}

        # --- YOUR EXPERT MAPPING ---
        blink_left, blink_right = self.blinks
        emg_left, emg_right = self.facial_muscles
        ground = self.ground

        # Assign your custom roles
        if blink_left in self.ch_dict:
            self.mapping[blink_left] = 'hEOG_L'
            self.regions['Periocular'].append(blink_left)
        if blink_right in self.ch_dict:
            self.mapping[blink_right] = 'hEOG_R'
            self.regions['Periocular'].append(blink_right)
        if emg_left in self.ch_dict:
            self.mapping[emg_left] = 'EMG_L'
            self.regions['Periocular'].append(emg_left)
        if emg_right in self.ch_dict:
            self.mapping[emg_right] = 'EMG_R'
            self.regions['Periocular'].append(emg_right)
        if ground in self.ch_dict:
            self.mapping[ground] = 'VEOG'
            self.regions['Periocular'].append(ground)

        # --- REST OF AUTOMATIC MAPPING ---
        for name, pos in self.ch_dict.items():
            if name in self.mapping:
                continue

            x, y, z = pos
            rel_x = x - head_center[0]
            is_left = rel_x < -15
            is_right = rel_x > 15
            is_midline = abs(rel_x) <= 15

            rel_y = y - head_center[1]
            is_frontal = rel_y > 30
            is_central = -30 <= rel_y <= 30
            is_parietal = -60 <= rel_y < -30
            is_occipital = rel_y < -60

            rel_z = z - head_center[2]
            is_temporal = rel_z < -20 and abs(rel_x) > 40

            # Fiducials
            if name.startswith('Fid'):
                if 'Nz' in name:
                    self.mapping[name] = 'Nz'
                elif 'T9' in name:
                    self.mapping[name] = 'A1'
                elif 'T10' in name:
                    self.mapping[name] = 'A2'
                self.regions['Other'].append(name)
                continue

            # Mastoid
            if name in ['FidT9', 'FidT10']:
                self.regions['Mastoid'].append(name)
                continue
            if 'FidT9' in self.fiducials and np.linalg.norm(pos - self.fiducials['FidT9']) < 25:
                self.regions['Mastoid'].append(name)
                self.mapping[name] = 'M1'
                continue
            elif 'FidT10' in self.fiducials and np.linalg.norm(pos - self.fiducials['FidT10']) < 25:
                self.regions['Mastoid'].append(name)
                self.mapping[name] = 'M2'
                continue

            # Vertex
            if name == 'Cz' or (is_midline and is_central and rel_z > 20):
                self.regions['Vertex'].append(name)
                self.mapping[name] = 'Cz'
                continue

            # Main regions
            if is_temporal:
                self.regions['Temporal'].append(name)
                self.mapping[name] = 'T7' if is_left and 'T7' not in self.mapping.values() else \
                                    'T8' if is_right and 'T8' not in self.mapping.values() else \
                                    'TP7' if is_left else 'TP8'
            elif is_frontal:
                self.regions['Frontal'].append(name)
                if is_midline:
                    self.mapping[name] = 'Fz' if 'Fz' not in self.mapping.values() else f'F{name[-1:]}'
                elif is_left:
                    self.mapping[name] = 'F3' if 'F3' not in self.mapping.values() else f'F{name[-1:]}'
                else:
                    self.mapping[name] = 'F4' if 'F4' not in self.mapping.values() else f'F{name[-1:]}'
            elif is_central:
                self.regions['Central'].append(name)
                if is_left:
                    self.mapping[name] = 'C3' if 'C3' not in self.mapping.values() else f'C{name[-1:]}'
                else:
                    self.mapping[name] = 'C4' if 'C4' not in self.mapping.values() else f'C{name[-1:]}'
            elif is_parietal:
                self.regions['Parietal'].append(name)
                if is_midline:
                    self.mapping[name] = 'Pz' if 'Pz' not in self.mapping.values() else f'P{name[-1:]}'
                elif is_left:
                    self.mapping[name] = 'P3' if 'P3' not in self.mapping.values() else f'P{name[-1:]}'
                else:
                    self.mapping[name] = 'P4' if 'P4' not in self.mapping.values() else f'P{name[-1:]}'
            elif is_occipital:
                self.regions['Occipital'].append(name)
                if is_midline:
                    self.mapping[name] = 'Oz' if 'Oz' not in self.mapping.values() else f'O{name[-1:]}'
                elif is_left:
                    self.mapping[name] = 'O1' if 'O1' not in self.mapping.values() else f'O{name[-1:]}'
                else:
                    self.mapping[name] = 'O2' if 'O2' not in self.mapping.values() else f'O{name[-1:]}'
            else:
                self.regions['Other'].append(name)
                self.mapping[name] = name

            # Laterality
            if is_left:
                self.regions['Left'].append(name)
            elif is_right:
                self.regions['Right'].append(name)
            else:
                self.regions['Midline'].append(name)

        return self

    def plot_3d_enhanced(self):
        """Plot 3D interactive visualization with modern aesthetics"""
        fig = go.Figure()
        colors = {
            'Frontal': '#FF6B6B',
            'Central': '#4ECDC4',
            'Parietal': '#45B7D1',
            'Occipital': '#96CEB4',
            'Temporal': '#FECA57',
            'Mastoid': '#6C5CE7',
            'Periocular': '#FD79A8',
            'Vertex': '#2D3436',
            'Other': '#DDD'
        }

        for region, color in colors.items():
            ch_names = self.regions.get(region, [])
            if not ch_names:
                continue
            data = [ch for ch in self.channels if ch[0] in ch_names]
            x = [c[1] for c in data]
            y = [c[2] for c in data]
            z = [c[3] for c in data]
            labels = [self.mapping.get(c[0], c[0]) for c in data]

            marker_size = 8 if region in ['Vertex', 'Periocular', 'Mastoid'] else 6
            marker_symbol = 'diamond' if region == 'Vertex' else 'circle'

            fig.add_trace(go.Scatter3d(
                x=x, y=y, z=z,
                mode='markers+text',
                text=labels,
                textposition='top center',
                textfont=dict(size=10, color='white'),
                marker=dict(size=marker_size, color=color, opacity=0.9, symbol=marker_symbol),
                name=f"{region} ({len(data)})",
                hovertemplate='<b>%{text}</b><br>X: %{x:.1f}<br>Y: %{y:.1f}<br>Z: %{z:.1f}<extra></extra>'
            ))

        fig.update_layout(
            scene=dict(
                xaxis=dict(title='X (mm)', backgroundcolor='rgb(20,20,20)', gridcolor='rgb(80,80,80)', showbackground=True),
                yaxis=dict(title='Y (mm)', backgroundcolor='rgb(20,20,20)', gridcolor='rgb(80,80,80)', showbackground=True),
                zaxis=dict(title='Z (mm)', backgroundcolor='rgb(20,20,20)', gridcolor='rgb(80,80,80)', showbackground=True),
                bgcolor='rgb(10,10,10)',
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
            ),
            title="GHW280 - EEG Channel Mapping",
            paper_bgcolor='rgb(10,10,10)',
            font=dict(color='white'),
            width=1200,
            height=900,
            legend=dict(bgcolor="rgba(0,0,0,0.5)", bordercolor="white")
        )
        fig.show()
        return self

    def plot_topdown_enhanced(self):
        """Plot top-down 2D view"""
        plt.style.use('dark_background')
        fig, ax = plt.subplots(figsize=(16, 14))
        colors = {
            'Frontal': '#FF6B6B',
            'Central': '#4ECDC4',
            'Parietal': '#45B7D1',
            'Occipital': '#96CEB4',
            'Temporal': '#FECA57',
            'Mastoid': '#6C5CE7',
            'Periocular': '#FD79A8',
            'Vertex': '#FFFFFF',
            'Other': '#CCCCCC'
        }

        for region, color in colors.items():
            ch_names = self.regions.get(region, [])
            if not ch_names:
                continue
            data = [ch for ch in self.channels if ch[0] in ch_names]
            x = [c[1] for c in data]
            y = [c[2] for c in data]
            labels = [self.mapping.get(c[0], c[0]) for c in data]

            marker_size = 120 if region in ['Vertex', 'Periocular'] else 80
            marker = 's' if region == 'Vertex' else 'o'

            ax.scatter(x, y, s=marker_size, c=color, marker=marker, alpha=0.8,
                       edgecolors='white', linewidths=1.5, label=f'{region} ({len(data)})')

            for xi, yi, lbl in zip(x, y, labels):
                ax.text(xi, yi, lbl, fontsize=8, ha='center', va='center', weight='bold', color='black',
                        bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.8, edgecolor='none'))

        ax.set_facecolor('#1a1a1a')
        ax.invert_yaxis()
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)
        ax.set_title('Top-Down View: Enhanced EEG Channel Layout', fontsize=18, color='white', pad=20)
        ax.set_xlabel('Y (mm)', fontsize=14, color='white')
        ax.set_ylabel('X (mm)', fontsize=14, color='white')
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=True, fancybox=True, shadow=True)
        plt.tight_layout()
        plt.show()
        return self

    def print_stats(self):
        """Print enhanced statistics"""
        print(f"\n📊 TOTAL CHANNELS: {len(self.channels)}")
        print(f"📍 SUCCESSFULLY MAPPED: {len([m for m in self.mapping.values() if m])}")

        print("\n🎯 ANATOMICAL REGION DISTRIBUTION:")
        print("-" * 50)
        for region in ['Frontal', 'Central', 'Parietal', 'Occipital', 'Temporal',
                       'Vertex', 'Periocular', 'Mastoid', 'Other']:
            chs = self.regions.get(region, [])
            if chs:
                percentage = (len(chs) / len(self.channels)) * 100
                print(f"{region:15} : {len(chs):3d} channels ({percentage:5.1f}%)")

        print("\n👁️ YOUR CUSTOM CHANNELS:")
        print("-" * 30)
        print(f"  Blinks:      E{self.blinks[0]} (left), E{self.blinks[1]} (right)")
        print(f"  Facial EMG:  E{self.facial_muscles[0]} (left), E{self.facial_muscles[1]} (right)")
        print(f"  Ground/VEOG: E{self.ground}")
        return self


# -------------------------------
# USAGE EXAMPLE
# -------------------------------
if __name__ == "__main__":
    # UPDATE THIS PATH TO YOUR FILE
    file_path = "/home/jaizor/jaizor/xtra/data/ghw280_from_egig.gpsc"

    analyzer = GHW280Analyzer(file_path) \
        .parse_gpsc() \
        .assign_regions_and_rename() \
        .print_stats() \
        .plot_3d_enhanced()
        # .plot_topdown_enhanced()  # Optional

    # Access results
    print("\n✅ Your custom channels are ready for ICA:")
    print(f"blinks = {analyzer.blinks}")
    print(f"facial_muscles = {analyzer.facial_muscles}")

Parsed 284 channels from /home/jaizor/jaizor/xtra/data/ghw280_from_egig.gpsc

📊 TOTAL CHANNELS: 284
📍 SUCCESSFULLY MAPPED: 284

🎯 ANATOMICAL REGION DISTRIBUTION:
--------------------------------------------------
Frontal         :  74 channels ( 26.1%)
Central         :  48 channels ( 16.9%)
Parietal        :  29 channels ( 10.2%)
Occipital       :  55 channels ( 19.4%)
Temporal        :  65 channels ( 22.9%)
Vertex          :   4 channels (  1.4%)
Mastoid         :   6 channels (  2.1%)
Other           :   3 channels (  1.1%)

👁️ YOUR CUSTOM CHANNELS:
------------------------------
  Blinks:      E274 (left), E41 (right)
  Facial EMG:  E280 (left), E52 (right)
  Ground/VEOG: E32



✅ Your custom channels are ready for ICA:
blinks = ['274', '41']
facial_muscles = ['280', '52']
