In [15]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import re
import statistics
from scipy.interpolate import interp1d
from scipy import signal

In [16]:
class WaveformProcessor:
    """
    Base class for storing and processing waveform data from different boards
    """
    def __init__(self, name=None):
        self.name = name or "Unnamed"
        self.channels = {}  # Main data structure

    def get_available_channels(self):
        """
        Get the channels that have data stored in the dictionary
        """
        return sorted(list(self.channels.keys()))

    def get_channel_waveform(self, waveform_type='singles', channel=None, trace_index=0):
        """
        Get a specific waveform of a channel
        
        Args:
            waveform_type (str): 'singles' or 'averages'
            channel (int): Channel number
            trace_index (int): Index of the trace to retrieve
            
        Returns:
            pd.DataFrame: The requested waveform dataframe
        """
        if channel not in self.channels:
            raise ValueError(f"Channel {channel} not found")  

        if waveform_type not in self.channels[channel]:
            raise ValueError(f"Channel {channel} does not have {waveform_type} data")

        if trace_index not in self.channels[channel][waveform_type]:
            raise ValueError(f"Channel {channel} does not have trace {trace_index} in {waveform_type} data. There are {len(self.channels[channel][waveform_type].keys())} traces available.")

        return self.channels[channel][waveform_type][trace_index]

    def get_pedestal(self, data, baseline_start_pct=0, baseline_end_pct=0.2):
        """
        Calculate the pedestal (baseline) of a waveform within a specified range of the data
        """
        start_idx = int(len(data) * baseline_start_pct)
        end_idx = int(len(data) * baseline_end_pct)
        return np.mean(data[start_idx:end_idx+1])

    def getPeakIndex(self, data, baseline_start_pct=0, baseline_end_pct=0.2, threshold=0.1):
        """
        Get the index of the peak value in a waveform above a specified threshold
        """
        peak_value = 0
        peak_index = 0
        threshold_index = 0
        crossed_threshold = False
        counter = 0
        pedestal = self.get_pedestal(data, baseline_start_pct, baseline_end_pct)
        for i in range(len(data)):
            if not crossed_threshold and data[i]-pedestal>threshold:
                crossed_threshold = True
                threshold_index = i
            if crossed_threshold and data[i]>peak_value:
                peak_value = data[i]
                peak_index = i
            elif crossed_threshold and data[i]<=peak_value:
                counter += 1
            if crossed_threshold and counter>1:
                break
        return peak_index, threshold_index
    
    def calculate_rise_time(self, channel, waveform_type='singles', trace_index=0, baseline_start_pct=0, baseline_end_pct=0.2, threshold=0.01, low_pct=0.1, high_pct=0.9):
        """
        Calculate rise time for a given channel.
        """
        df = self.get_channel_waveform(waveform_type, channel, trace_index)
        time = df["time"].values * 1e9 # Convert to ns
        signal = df["output"].values * 1e3 # Convert to mV

        pedestal = self.get_pedestal(signal, baseline_start_pct, baseline_end_pct)
        peak_index,threshold_index = self.getPeakIndex(signal, baseline_start_pct, baseline_end_pct, threshold)
        amplitude = signal[peak_index] - pedestal
        low_threshold = amplitude * low_pct - pedestal 
        high_threshold = amplitude * high_pct - pedestal

        interp_func = interp1d(time, signal, kind='linear') 
        interp_time = np.linspace(time[0], time[-1], num=10000)
        interp_signal = interp_func(interp_time)
        signal_above_low = interp_signal >= low_threshold
        signal_above_high = interp_signal >= high_threshold
        if np.any(signal_above_low):
            t_low_index = np.argmax(signal_above_low)
            t_low = interp_time[t_low_index]
        else:
            t_low = np.nan

        if np.any(signal_above_high):
            t_high_index = np.argmax(signal_above_high)
            t_high = interp_time[t_high_index]
        else:
            t_high = np.nan

        rise_time = t_high - t_low

        if 'analysis' not in self.channels[channel]:
            self.channels[channel]['analysis'] = {}

        self.channels[channel]['analysis']['rise_time'] = {
            "value": rise_time,
            "source": waveform_type,
            "trace_index": trace_index,
            "t_low": t_low,
            "t_high": t_high,
            "low_threshold": low_threshold,
            "high_threshold": high_threshold,
            "amplitude": amplitude,
            "pedestal": pedestal
        }
        return rise_time, t_low, t_high
    
    def calculate_all_rise_times(self, waveform_type='singles', trace_index=0,
                               baseline_start_pct=0, baseline_end_pct=0.2, 
                               threshold=0.01, low_pct=0.1, high_pct=0.9):
        """
        Calculate rise times for all loaded channels.
        
        Args:
            waveform_type (str): 'singles' or 'averages'
            trace_index (int): Index of the trace to use
            baseline_start_pct (float): Start point for baseline calculation
            baseline_end_pct (float): End point for baseline calculation
            threshold (float): Threshold for peak detection
            low_pct (float): Lower threshold percentage (0-1)
            high_pct (float): Upper threshold percentage (0-1)
            
        Returns:
            dict: Dictionary of rise times by channel
        """
        results = {}
        for channel in self.channels:
            try:
                if waveform_type in self.channels[channel] and trace_index in self.channels[channel][waveform_type]:
                    rt, t_low, t_high = self.calculate_rise_time(
                        channel, waveform_type, trace_index,
                        baseline_start_pct, baseline_end_pct, 
                        threshold, low_pct, high_pct
                    )
                    results[channel] = rt
                    print(f"Channel {channel}: Rise time = {rt*1e9:.2f} ns")
                else:
                    print(f"Channel {channel}: No {waveform_type} data at index {trace_index}")
                    results[channel] = np.nan
            except Exception as e:
                print(f"Error processing channel {channel}: {e}")
                results[channel] = np.nan
                
        return results
    
    def calculate_delays(self, waveform_type='averages', trace_index=0, reference_channel=None):
        """
        Calculate delays relative to reference channel.
        
        Args:
            waveform_type (str): 'singles' or 'averages'
            trace_index (int): Index of the trace to use
            reference_channel (int, optional): Channel to use as reference. If None, uses channel with minimum rise time.
            
        Returns:
            dict: Dictionary of delays by channel in picoseconds
        """
        # First ensure we have rise time data
        channels_with_rt = {}
        for ch in self.channels:
            if ('analysis' in self.channels[ch] and 
                'rise_time' in self.channels[ch]['analysis']):
                channels_with_rt[ch] = self.channels[ch]['analysis']['rise_time']
        
        if not channels_with_rt:
            print("No rise time measurements found. Calculating...")
            self.calculate_all_rise_times(waveform_type, trace_index)
            
            # Check again
            channels_with_rt = {}
            for ch in self.channels:
                if ('analysis' in self.channels[ch] and 
                    'rise_time' in self.channels[ch]['analysis']):
                    channels_with_rt[ch] = self.channels[ch]['analysis']['rise_time']
            
            if not channels_with_rt:
                raise ValueError("Failed to calculate rise times. Cannot calculate delays.")
        
        # Find reference channel if not specified
        if reference_channel is None:
            min_rt = float('inf')
            for ch, rt_data in channels_with_rt.items():
                if not np.isnan(rt_data['value']) and rt_data['value'] < min_rt:
                    min_rt = rt_data['value']
                    reference_channel = ch
            
            if reference_channel is None:
                raise ValueError("No valid rise time found for any channel.")
        
        if reference_channel not in channels_with_rt:
            raise ValueError(f"Reference channel {reference_channel} has no rise time data")
        
        # Get reference t_low time
        ref_t_low = channels_with_rt[reference_channel]['t_low']
        
        # Calculate delays for all channels
        delays = {}
        for ch in self.channels:
            if ('analysis' in self.channels[ch] and 
                'rise_time' in self.channels[ch]['analysis'] and 
                not np.isnan(self.channels[ch]['analysis']['rise_time']['t_low'])):
                
                # Calculate delay in picoseconds
                ch_t_low = self.channels[ch]['analysis']['rise_time']['t_low']
                delay_ps = (ch_t_low - ref_t_low) * 1e12
                
                # Store result
                if 'analysis' not in self.channels[ch]:
                    self.channels[ch]['analysis'] = {}
                
                self.channels[ch]['analysis']['delay'] = {
                    'value': delay_ps,
                    'reference_channel': reference_channel,
                    'source': waveform_type,
                    'trace_index': trace_index
                }
                
                delays[ch] = delay_ps
            else:
                delays[ch] = np.nan
        
        return delays
    
    def calculate_gains(self, waveform_type='averages', trace_index=0):
        """
        Calculate gains for all channels by comparing input and output amplitudes.
        
        Args:
            waveform_type (str): 'singles' or 'averages'
            trace_index (int): Index of the trace to use
            
        Returns:
            dict: Dictionary of gains by channel
        """
        gains = {}
        
        for ch in self.channels:
            try:
                if (waveform_type in self.channels[ch] and 
                    trace_index in self.channels[ch][waveform_type]):
                    
                    df = self.get_channel_waveform(waveform_type, ch, trace_index)
                    
                    # Check if we have both input and output columns
                    if "input" in df.columns and ("output" in df.columns or "amplitude" in df.columns):
                        # Get input and output signals
                        input_signal = df["input"].values
                        output_signal = df["amplitude"].values if "amplitude" in df.columns else df["output"].values
                        
                        # Calculate pedestals
                        input_pedestal = self.get_pedestal(input_signal)
                        output_pedestal = self.get_pedestal(output_signal)
                        
                        # Adjust signals
                        input_adj = input_signal - input_pedestal
                        output_adj = output_signal - output_pedestal
                        
                        # Find peaks
                        input_peak = np.max(np.abs(input_adj))
                        output_peak = np.max(np.abs(output_adj))
                        
                        # Calculate gain (output/input)
                        if input_peak != 0:
                            gain = output_peak / input_peak
                        else:
                            gain = np.nan
                            
                        # Store in channel's analysis
                        if 'analysis' not in self.channels[ch]:
                            self.channels[ch]['analysis'] = {}
                            
                        self.channels[ch]['analysis']['gain'] = {
                            'value': gain,
                            'input_peak': input_peak,
                            'output_peak': output_peak,
                            'source': waveform_type,
                            'trace_index': trace_index
                        }
                        
                        gains[ch] = gain
                        print(f"Channel {ch}: Gain = {gain:.4f}")
                    else:
                        print(f"Channel {ch} missing input or output data")
                        gains[ch] = np.nan
                else:
                    print(f"No {waveform_type} data for channel {ch} at index {trace_index}")
                    gains[ch] = np.nan
            except Exception as e:
                print(f"Error calculating gain for channel {ch}: {e}")
                gains[ch] = np.nan
                
        return gains
    
    def plot_waveform(self, channel, waveform_type='singles', trace_index=0, show_thresholds=True):
        """
        Plot a waveform for a specific channel.
        
        Args:
            channel (int): Channel identifier
            waveform_type (str): 'singles' or 'averages'
            trace_index (int): Index of the trace to plot
            show_thresholds (bool): Whether to show threshold markers
            
        Returns:
            matplotlib.figure.Figure: The figure object
        """
        df = self.get_channel_waveform(waveform_type, channel, trace_index)
        
        fig, ax = plt.subplots(figsize=(10, 6))
        
        # Plot the waveform
        time_ns = df["time"].values * 1e9  # Convert to nanoseconds
        if "amplitude" in df.columns:
            signal = df["amplitude"].values
        elif "output" in df.columns:
            signal = df["output"].values
        else:
            raise ValueError(f"No amplitude or output column found in {waveform_type} data for channel {channel}")
        
        # Get pedestal if available or calculate it
        if ('analysis' in self.channels[channel] and 
            'rise_time' in self.channels[channel]['analysis']):
            pedestal = self.channels[channel]['analysis']['rise_time']['pedestal']
        else:
            pedestal = self.get_pedestal(signal)
            
        signal_adj = signal - pedestal
            
        ax.plot(time_ns, signal_adj, label=f"{self.name} Ch{channel} ({waveform_type})")
        
        # Show threshold markers if requested and available
        if (show_thresholds and 'analysis' in self.channels[channel] and 
            'rise_time' in self.channels[channel]['analysis']):
            
            rt_data = self.channels[channel]['analysis']['rise_time']
            t_low_ns = rt_data['t_low']
            t_high_ns = rt_data['t_high']
            low_threshold = rt_data['low_threshold'] - pedestal  # Adjust for plotting
            high_threshold = rt_data['high_threshold'] - pedestal
            
            # Add vertical lines at threshold crossings
            ax.axvline(x=t_low_ns, color='green', linestyle='--', 
                      label=f'tLow = {t_low_ns:.2f} ns')
            ax.axvline(x=t_high_ns, color='red', linestyle='--', 
                      label=f'tHigh = {t_high_ns:.2f} ns')
            
            # Add horizontal lines at threshold levels
            ax.axhline(y=low_threshold, color='green', linestyle=':')
            ax.axhline(y=high_threshold, color='red', linestyle=':')
            
            # Add rise time text
            rise_time_ns = rt_data['value']
            ax.text(0.7, 0.9, f"Rise Time: {rise_time_ns:.2f} ns", 
                   transform=ax.transAxes, fontsize=12, 
                   bbox=dict(facecolor='white', alpha=0.7))
            
        ax.set_xlabel("Time (ns)")
        ax.set_ylabel("Amplitude")
        ax.set_title(f"{self.name} Channel {channel} Waveform ({waveform_type})")
        ax.grid(True)
        ax.legend()
        
        plt.tight_layout()
        return fig
    
    def plot_delays(self, highlight_extremes=True):
        """
        Plot channel delays as a bar chart.
        
        Args:
            highlight_extremes (bool): Whether to highlight min/max delays
            
        Returns:
            matplotlib.figure.Figure: The figure object
        """
        # Get all channels with delay measurements
        channels_with_delays = {}
        for ch in self.channels:
            if ('analysis' in self.channels[ch] and 
                'delay' in self.channels[ch]['analysis']):
                channels_with_delays[ch] = self.channels[ch]['analysis']['delay']['value']
        
        if not channels_with_delays:
            raise ValueError("No delay measurements found. Run calculate_delays first.")
            
        # Create labels and delay array
        ch_nums = sorted(list(channels_with_delays.keys()))
        labels = [f"CH{ch}" for ch in ch_nums]
        delays = [channels_with_delays[ch] for ch in ch_nums]
        
        fig, ax = plt.subplots(figsize=(12, 6))
        
        # Create bar plot
        bars = ax.bar(labels, delays)
        
        if highlight_extremes:
            # Find min (excluding 0 which is reference)
            non_zero_delays = [d for d in delays if d != 0]
            if non_zero_delays:
                min_idx = delays.index(min(non_zero_delays))
                max_idx = delays.index(max(delays))
                
                # Find second lowest if needed
                if min_idx == delays.index(0):  # If the minimum is the reference channel (0 delay)
                    # Make a copy and replace the lowest with infinity
                    temp_values = delays.copy()
                    temp_values[min_idx] = float('inf')
                    min_idx = temp_values.index(min(temp_values))
                
                # Highlight min/max bars
                bars[min_idx].set_color('lightgreen')
                bars[max_idx].set_color('tomato')
                
                # Add value labels
                ax.text(min_idx, delays[min_idx]+20, f"{delays[min_idx]:.0f}", 
                       ha='center', rotation=90)
                ax.text(max_idx, delays[max_idx]+20, f"{delays[max_idx]:.0f}", 
                       ha='center', rotation=90)
        
        # Calculate standard deviation
        std_dev = np.std([d for d in delays if not np.isnan(d) and d != 0])
        
        ax.set_xlabel("Channel")
        ax.set_ylabel("Delay Relative to Reference Channel [ps]")
        ax.set_title(f"{self.name} Unity Path Relative Channel Delays")
        ax.set_ylim(0, max(delays) * 1.2)  # Add 20% margin
        ax.legend([f'$\\sigma= {std_dev:.2f}$'])
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        
        plt.tight_layout()
        return fig
    
    def plot_multiple_traces(self, channel, waveform_type='singles', max_traces=5, ax=None):
        """
        Plot multiple traces for a channel on the same axes.
        
        Args:
            channel (int): Channel number
            waveform_type (str): 'singles' or 'averages'
            max_traces (int): Maximum number of traces to plot
            ax (matplotlib.axes.Axes, optional): Axes to plot on
            
        Returns:
            matplotlib.axes.Axes: The plot axes
        """
        if channel not in self.channels or waveform_type not in self.channels[channel]:
            raise ValueError(f"No {waveform_type} data found for channel {channel}")
            
        traces = self.channels[channel][waveform_type]
        if not traces:
            raise ValueError(f"No traces found for channel {channel} in {wf_type}")
            
        # Create or use provided axes
        if ax is None:
            fig, ax = plt.subplots(figsize=(10, 6))
            
        # Limit number of traces to prevent overcrowding
        trace_nums = sorted(list(traces.keys()))[:max_traces]
        
        for i, trace_num in enumerate(trace_nums):
            df = traces[trace_num]
            if "output" in df.columns:
                label = f"Trace {trace_num}"
                ax.plot(df["time"] * 1e9, df["output"], label=label, alpha=0.7)
            elif "amplitude" in df.columns:
                label = f"Trace {trace_num}"
                ax.plot(df["time"] * 1e9, df["amplitude"], label=label, alpha=0.7)
            
        ax.set_xlabel("Time (ns)")
        ax.set_ylabel("Output")
        ax.set_title(f"{self.name} Channel {channel} - {waveform_type.capitalize()} Traces")
        ax.grid(True)
        ax.legend()
        
        return ax
    
    def load_data(self, file_pattern=None):
        """
        Load data files. To be implemented by subclasses.
        
        Args:
            file_pattern (str): Pattern to match files
            
        Returns:
            int: Number of channels loaded
        """
        raise NotImplementedError("Subclasses must implement load_data()")


In [17]:


class CASB1Processor(WaveformProcessor):
    """Processor for CASB1 board data."""
    
    def __init__(self):
        super().__init__(name="CASB1")
    
    def load_singles(self, path="../data/casb1/singles/C1--Trace--*.txt"):
        """
        Load CASB1 singles data from text files.
        """
        files = glob.glob(path)
        if not files:
            print(f"Warning: No files found matching pattern: {path}")
            return {}
            
        files_per_channel = {}
        
        for file in files:
            try:
                filename = os.path.basename(file)
                match = re.search(r'C(\d+)--Trace--(\d+)', filename)
                if match:
                    channel = int(match.group(1))
                    trace_num = int(match.group(2))
                else:
                    match = re.search(r'Trace--(\d+)', filename)
                    if match:
                        trace_num = int(match.group(1))
                        channel = 1  # Default if no channel in filename
                    else:
                        print(f"Could not extract info from {filename}, skipping")
                        continue
                
                # Load data
                df = pd.read_csv(file, skiprows=6, names=["time", "output"])
                for col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')
                
                # Initialize channel structure if needed
                if channel not in self.channels:
                    self.channels[channel] = {}
                
                # Initialize singles dict if needed
                if 'singles' not in self.channels[channel]:
                    self.channels[channel]['singles'] = {}
                
                # Store the trace
                self.channels[channel]['singles'][trace_num] = df
                
                # Update counter
                if channel in files_per_channel:
                    files_per_channel[channel] += 1
                else:
                    files_per_channel[channel] = 1
                    
            except Exception as e:
                print(f"Error processing file {file}: {e}")
                
        # Print summary
        total_files = sum(files_per_channel.values())
        print(f"Loaded {total_files} singles files across {len(files_per_channel)} channels for {self.name}")
        
        return files_per_channel
    
    def load_averages(self, path="../data/casb1/averages/ch*.csv"):
        """
        Load CASB1 averages data from CSV files.
        """
        files = glob.glob(path)
        if not files:
            print(f"Warning: No files found matching pattern: {path}")
            return {}
            
        files_per_channel = {}
        
        for file in files:
            try:
                filename = os.path.basename(file)
                match = re.search(r'ch(\d+)', filename)
                if match:
                    channel = int(match.group(1))
                else:
                    print(f"Could not extract channel from {filename}, skipping")
                    continue
                
                # Extract trace number if available, otherwise use counter
                if channel in files_per_channel:
                    trace_num = files_per_channel[channel]
                else:
                    trace_num = 0
                
                # Load data
                df = pd.read_csv(file, skiprows=21, names=["time", "output", "input", "CH3"])
                for col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')
                
                # Keep relevant columns
                df = df[["time", "output", "input"]] if "input" in df.columns else df[["time", "output"]]
                
                # Initialize channel structure if needed
                if channel not in self.channels:
                    self.channels[channel] = {}
                
                # Initialize averages dict if needed
                if 'averages' not in self.channels[channel]:
                    self.channels[channel]['averages'] = {}
                
                # Store the trace
                self.channels[channel]['averages'][trace_num] = df
                
                # Update counter
                if channel in files_per_channel:
                    files_per_channel[channel] += 1
                else:
                    files_per_channel[channel] = 1
                    
            except Exception as e:
                print(f"Error processing file {file}: {e}")
                
        # Print summary
        total_files = sum(files_per_channel.values())
        print(f"Loaded {total_files} averages files across {len(files_per_channel)} channels for {self.name}")
        
        return files_per_channel
    
    def load_data(self, singles_path=None, averages_path=None):
        """
        Load both singles and averages data.
        
        Args:
            singles_path (str, optional): Path to singles files
            averages_path (str, optional): Path to averages files
            
        Returns:
            tuple: Number of channels with singles data, number with averages data
        """
        # Load singles if path provided
        singles_result = {}
        if singles_path:
            singles_result = self.load_singles(singles_path)
        
        # Load averages if path provided
        averages_result = {}
        if averages_path:
            averages_result = self.load_averages(averages_path)
        
        return len(singles_result), len(averages_result)


In [18]:

class CASB2Processor(WaveformProcessor):
    """Processor for CASB2 board data."""
    
    def __init__(self):
        super().__init__(name="CASB2")
    
    def load_singles(self, path="../data/casb2/singles/ch*/tek*ALL.csv"):
        """
        Load CASB2 singles data from CSV files.
        """
        files = glob.glob(path)
        if not files:
            print(f"Warning: No files found matching pattern: {path}")
            return {}
            
        files_per_channel = {}
        
        for file in files:
            try:
                filename = os.path.basename(file)
                ch_dir = os.path.basename(os.path.dirname(file))
                
                # Try to extract channel from directory first
                ch_match = re.search(r'ch(\d+)', ch_dir)
                if ch_match:
                    channel = int(ch_match.group(1))
                else:
                    # Try from filename
                    ch_match = re.search(r'ch(\d+)', filename)
                    if ch_match:
                        channel = int(ch_match.group(1))
                    else:
                        print(f"Could not extract channel from {file}, skipping")
                        continue
                
                # Extract trace number if available
                trace_match = re.search(r'tek(\d+)ALL', filename)
                if trace_match:
                    trace_num = int(trace_match.group(1))
                else:
                    # Use a counter based on how many files we've seen for this channel
                    if channel in files_per_channel:
                        trace_num = files_per_channel[channel]
                    else:
                        trace_num = 0
                
                # Read the data
                df = pd.read_csv(file, skiprows=21, names=["time", "output", "input", "CH3"])
                
                # Ensure numeric values
                for col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')
                
                # Keep only the columns we need
                df = df[["time", "output", "input"]] if "input" in df.columns else df[["time", "output"]]
                
                # Initialize channel structure if needed
                if channel not in self.channels:
                    self.channels[channel] = {}
                
                # Initialize singles dict if needed
                if 'singles' not in self.channels[channel]:
                    self.channels[channel]['singles'] = {}
                
                # Store trace with its number
                self.channels[channel]['singles'][trace_num] = df
                
                # Update counter
                if channel in files_per_channel:
                    files_per_channel[channel] += 1
                else:
                    files_per_channel[channel] = 1
                    
            except Exception as e:
                print(f"Error processing file {file}: {e}")
        
        # Print summary
        total_files = sum(files_per_channel.values())
        print(f"Loaded {total_files} singles files across {len(files_per_channel)} channels for {self.name}")
        
        return files_per_channel
    
    def load_averages(self, path="../data/casb2/averages/ch*/tek*ALL.csv"):
        """
        Load CASB2 averages data from CSV files.
        """
        files = glob.glob(path)
        if not files:
            print(f"Warning: No files found matching pattern: {path}")
            return {}
            
        files_per_channel = {}
        
        for file in files:
            try:
                filename = os.path.basename(file)
                ch_dir = os.path.basename(os.path.dirname(file))
                
                # Try to extract channel from directory first
                ch_match = re.search(r'ch(\d+)', ch_dir)
                if ch_match:
                    channel = int(ch_match.group(1))
                else:
                    # Try from filename
                    ch_match = re.search(r'ch(\d+)', filename)
                    if ch_match:
                        channel = int(ch_match.group(1))
                    else:
                        print(f"Could not extract channel from {file}, skipping")
                        continue
                
                # Extract trace number if available
                trace_match = re.search(r'tek(\d+)ALL', filename)
                if trace_match:
                    trace_num = int(trace_match.group(1))
                else:
                    # Use a counter based on how many files we've seen for this channel
                    if channel in files_per_channel:
                        trace_num = files_per_channel[channel]
                    else:
                        trace_num = 0
                
                # Read the data
                df = pd.read_csv(file, skiprows=21, names=["time", "output", "input", "CH3"])
                
                # Ensure numeric values
                for col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')
                
                # Keep only the columns we need
                df = df[["time", "output", "input"]] if "input" in df.columns else df[["time", "output"]]
                
                # Initialize channel structure if needed
                if channel not in self.channels:
                    self.channels[channel] = {}
                
                # Initialize averages dict if needed
                if 'averages' not in self.channels[channel]:
                    self.channels[channel]['averages'] = {}
                
                # Store trace with its number
                self.channels[channel]['averages'][trace_num] = df
                
                # Update counter
                if channel in files_per_channel:
                    files_per_channel[channel] += 1
                else:
                    files_per_channel[channel] = 1
                    
            except Exception as e:
                print(f"Error processing file {file}: {e}")
        
        # Print summary
        total_files = sum(files_per_channel.values())
        print(f"Loaded {total_files} averages files across {len(files_per_channel)} channels for {self.name}")
        
        return files_per_channel
    
    def load_data(self, singles_path=None, averages_path=None):
        """
        Load both singles and averages data.
        
        Args:
            singles_path (str, optional): Path to singles files
            averages_path (str, optional): Path to averages files
            
        Returns:
            tuple: Number of channels with singles data, number with averages data
        """
        # Load singles if path provided
        singles_result = {}
        if singles_path:
            singles_result = self.load_singles(singles_path)
        else:
            singles_result = self.load_singles()  # Use default path
        
        # Load averages if path provided
        averages_result = {}
        if averages_path:
            averages_result = self.load_averages(averages_path)
        else:
            averages_result = self.load_averages()  # Use default path
        
        return len(singles_result), len(averages_result)



In [19]:

class MTCAProcessor(WaveformProcessor):
    """Processor for MTCA board data."""
    
    def __init__(self):
        super().__init__(name="MTCA")
    
    def load_singles(self, path="../data/mtca1/singles/C4--Trace--*.txt"):
        """
        Load MTCA singles data from text files.
        """
        files = glob.glob(path)
        if not files:
            print(f"Warning: No files found matching pattern: {path}")
            return {}
            
        files_per_channel = {}
        
        for file in files:
            try:
                filename = os.path.basename(file)
                
                # Extract channel and trace info
                match = re.search(r'C(\d+)--Trace--(\d+)', filename)
                if match:
                    channel = int(match.group(1))
                    trace_num = int(match.group(2))
                else:
                    match = re.search(r'Trace--(\d+)', filename)
                    if match:
                        trace_num = int(match.group(1))
                        channel = 4  # Default if not specified
                    else:
                        print(f"Could not extract info from {filename}, skipping")
                        continue
                
                # Load data - adjust skiprows based on your file format
                df = pd.read_csv(file, skiprows=6, names=["time", "output"])
                for col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')
                
                # Initialize channel structure if needed
                if channel not in self.channels:
                    self.channels[channel] = {}
                
                # Initialize singles dict if needed
                if 'singles' not in self.channels[channel]:
                    self.channels[channel]['singles'] = {}
                
                # Store the trace
                self.channels[channel]['singles'][trace_num] = df
                
                # Update counter
                if channel in files_per_channel:
                    files_per_channel[channel] += 1
                else:
                    files_per_channel[channel] = 1
                    
            except Exception as e:
                print(f"Error processing file {file}: {e}")
                
        # Print summary
        total_files = sum(files_per_channel.values())
        print(f"Loaded {total_files} singles files across {len(files_per_channel)} channels for {self.name}")
        
        return files_per_channel
    
    def load_averages(self, path="../data/mtca1/averages/ch*.csv"):
        """
        Load MTCA averages data from CSV files.
        """
        files = glob.glob(path)
        if not files:
            print(f"Warning: No files found matching pattern: {path}")
            return {}
            
        files_per_channel = {}
        
        for file in files:
            try:
                filename = os.path.basename(file)
                match = re.search(r'ch(\d+)', filename)
                if match:
                    channel = int(match.group(1))
                else:
                    print(f"Could not extract channel from {filename}, skipping")
                    continue
                
                # Extract trace number if available, otherwise use counter
                if channel in files_per_channel:
                    trace_num = files_per_channel[channel]
                else:
                    trace_num = 0
                
                # Load data - adjust skiprows and column names based on your file format
                df = pd.read_csv(file)
                
                # Find time and output columns
                time_col = None
                output_col = None
                
                for col in df.columns:
                    if col.lower() in ("time", "times"):
                        time_col = col
                    elif col.lower() in ("output", "amplitude", "waveform"):
                        output_col = col
                
                if not time_col or not output_col:
                    print(f"Could not identify time/output columns in {filename}, using first two columns")
                    time_col = df.columns[0]
                    output_col = df.columns[1]
                
                # Create new dataframe with standardized column names
                new_df = pd.DataFrame({
                    "time": pd.to_numeric(df[time_col], errors='coerce'),
                    "output": pd.to_numeric(df[output_col], errors='coerce')
                })
                
                # Initialize channel structure if needed
                if channel not in self.channels:
                    self.channels[channel] = {}
                
                # Initialize averages dict if needed
                if 'averages' not in self.channels[channel]:
                    self.channels[channel]['averages'] = {}
                
                # Store the trace
                self.channels[channel]['averages'][trace_num] = new_df
                
                # Update counter
                if channel in files_per_channel:
                    files_per_channel[channel] += 1
                else:
                    files_per_channel[channel] = 1
                    
            except Exception as e:
                print(f"Error processing file {file}: {e}")
                
        # Print summary
        total_files = sum(files_per_channel.values())
        print(f"Loaded {total_files} averages files across {len(files_per_channel)} channels for {self.name}")
        
        return files_per_channel
    
    def load_data(self, singles_path=None, averages_path=None):
        """
        Load both singles and averages data.
        
        Args:
            singles_path (str, optional): Path to singles files
            averages_path (str, optional): Path to averages files
            
        Returns:
            tuple: Number of channels with singles data, number with averages data
        """
        # Load singles if path provided
        singles_result = {}
        if singles_path:
            singles_result = self.load_singles(singles_path)
        else:
            singles_result = self.load_singles()  # Use default path
        
        # Load averages if path provided
        averages_result = {}
        if averages_path:
            averages_result = self.load_averages(averages_path)
        else:
            try:
                averages_result = self.load_averages()  # Use default path
            except Exception as e:
                print(f"Warning: Could not load averages with default path: {e}")
        
        return len(singles_result), len(averages_result)


In [20]:


def plot_combined_delays(boards, waveform_type='averages', trace_index=0, highlight_extremes=True):
    """
    Create a combined bar plot showing delays from multiple boards.
    
    Args:
        boards (list): List of WaveformProcessor instances
        waveform_type (str): 'singles' or 'averages'
        trace_index (int): Index of the trace to use
        highlight_extremes (bool): Whether to highlight min/max delays
        
    Returns:
        matplotlib.figure.Figure: The figure object
    """
    # Make sure all boards have delay analysis
    for board in boards:
        channels_with_delays = 0
        for ch in board.channels:
            if ('analysis' in board.channels[ch] and 
                'delay' in board.channels[ch]['analysis']):
                channels_with_delays += 1
        
        if channels_with_delays == 0:
            print(f"Calculating delays for {board.name} using {waveform_type}[{trace_index}]")
            try:
                board.calculate_delays(waveform_type, trace_index)
            except Exception as e:
                print(f"Warning: Could not calculate delays for {board.name}: {e}")
    
    # Get all channels across all boards
    all_channels = set()
    for board in boards:
        all_channels.update(board.channels.keys())
    all_channels = sorted(list(all_channels))
    
    # Create labels
    labels = [f"CH{ch}" for ch in all_channels]
    
    # Set up the plot
    fig, ax = plt.subplots(figsize=(15, 8))
    
    # Calculate positions for bars
    x = np.arange(len(labels))
    width = 0.8 / len(boards)  # Width of the bars, adjusted for number of boards
    
    # Plot bars for each board
    legend_items = []
    
    for i, board in enumerate(boards):
        # Create dictionary mapping channel to delay
        delay_dict = {}
        for ch in board.channels:
            if ('analysis' in board.channels[ch] and 
                'delay' in board.channels[ch]['analysis']):
                delay_dict[ch] = board.channels[ch]['analysis']['delay']['value']
        
        # Fill in delays for all channels (use NaN for missing channels)
        delays = [delay_dict.get(ch, np.nan) for ch in all_channels]
        
        # Calculate position adjustment
        pos_adjustment = width * (i - (len(boards) - 1) / 2)
        
        # Create bars
        bars = ax.bar(x + pos_adjustment, delays, width, label=board.name)
        
        if highlight_extremes:
            # Find non-NaN values
            valid_delays = [(idx, d) for idx, d in enumerate(delays) if not np.isnan(d)]
            if valid_delays:
                valid_indices, valid_values = zip(*valid_delays)
                
                # Find min (excluding 0 which might be reference)
                non_zero_values = [d for d in valid_values if d != 0]
                if non_zero_values:
                    min_val = min(non_zero_values)
                    min_idx = valid_indices[valid_values.index(min_val)]
                    max_val = max(valid_values)
                    max_idx = valid_indices[valid_values.index(max_val)]
                    
                    # Highlight min/max bars
                    bars[min_idx].set_color('lightgreen')
                    bars[max_idx].set_color('tomato')
                    
                    # Add value labels
                    ax.text(x[min_idx] + pos_adjustment, delays[min_idx]+20, 
                           f"{delays[min_idx]:.0f}", ha='center', rotation=90, fontsize=8)
                    ax.text(x[max_idx] + pos_adjustment, delays[max_idx]+20, 
                           f"{delays[max_idx]:.0f}", ha='center', rotation=90, fontsize=8)
        
        # Calculate standard deviation for legend
        valid_delays = [d for d in delays if not np.isnan(d) and d != 0]
        std_dev = np.std(valid_delays) if valid_delays else 0
        legend_items.append(f'{board.name} ($\\sigma= {std_dev:.2f}$)')
    
    # Customize the plot
    ax.set_xlabel("Channel")
    ax.set_ylabel("Delay Relative to Reference Channel [ps]")
    ax.set_title(f"Unity Path Relative Channel Delays Comparison ({waveform_type})")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=90)
    
    # Find appropriate y-limit - skip NaN and inf values
    all_delays = []
    for board in boards:
        for ch in board.channels:
            if ('analysis' in board.channels[ch] and 
                'delay' in board.channels[ch]['analysis']):
                value = board.channels[ch]['analysis']['delay']['value']
                if not np.isnan(value) and not np.isinf(value):
                    all_delays.append(value)
    
    if all_delays:
        max_delay = max(all_delays)
        ax.set_ylim(0, max_delay * 1.2)  # Add 20% margin
    
    # Add legend with standard deviations
    ax.legend(legend_items)
    
    # Add grid
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    
    plt.tight_layout()
    return fig


In [23]:


# Example of how to use the classes:

# Initialize processors
casb1 = CASB1Processor()
casb2 = CASB2Processor()
mtca = MTCAProcessor()

# Load data
casb1.load_data()
casb2.load_data()
mtca.load_data()

# Calculate rise times using singles data
casb1.calculate_all_rise_times(waveform_type='singles', trace_index=0)
casb2.calculate_all_rise_times(waveform_type='singles', trace_index=0)
mtca.calculate_all_rise_times(waveform_type='singles', trace_index=0)

# Calculate delays using averages data
casb1.calculate_delays(waveform_type='averages', trace_index=0)
casb2.calculate_delays(waveform_type='averages', trace_index=0)
mtca.calculate_delays(waveform_type='averages', trace_index=0)

# Calculate gains using averages data
casb1.calculate_gains(waveform_type='averages', trace_index=0)
casb2.calculate_gains(waveform_type='averages', trace_index=0)
mtca.calculate_gains(waveform_type='averages', trace_index=0)

# Plot a sample waveform
casb1.plot_waveform(channel=1, waveform_type='singles', trace_index=0, show_thresholds=True)

# Plot delays for individual boards
casb1.plot_delays(highlight_extremes=True)
casb2.plot_delays(highlight_extremes=True)

# Plot combined delays
plot_combined_delays([casb1, casb2, mtca], waveform_type='averages', trace_index=0)

Loaded 401 singles files across 20 channels for CASB2
Loaded 35 singles files across 1 channels for MTCA
Channel 1: Rise time = 0.00 ns
Channel 2: Rise time = 0.00 ns
Channel 6: Rise time = 0.00 ns
Channel 14: Rise time = 0.00 ns
Channel 17: Rise time = 0.00 ns
Channel 10: Rise time = 0.00 ns
Channel 20: Rise time = 0.00 ns
Channel 9: Rise time = 0.00 ns
Channel 19: Rise time = 0.00 ns
Channel 3: Rise time = 0.00 ns
Channel 11: Rise time = 0.00 ns
Channel 13: Rise time = 0.00 ns
Channel 5: Rise time = 0.00 ns
Channel 8: Rise time = 0.00 ns
Channel 12: Rise time = 0.00 ns
Channel 16: Rise time = 0.00 ns
Channel 4: Rise time = 0.00 ns
Channel 15: Rise time = 0.00 ns
Channel 7: Rise time = 0.00 ns
Channel 18: Rise time = 0.00 ns
Channel 4: Rise time = nan ns
No rise time measurements found. Calculating...


ValueError: Failed to calculate rise times. Cannot calculate delays.