In [73]:
# Restructured Analysis Notebook
# This code implements a class-based approach for analyzing waveform data from CASB1, CASB2, and MTCA boards

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import re
import statistics
from scipy.interpolate import interp1d
from scipy import signal

In [74]:


class WaveformProcessor:
    """
    Base class for storing and processing waveform data from different boards
    """
    def __init__(self, name=None):
        self.name = name or "Unnamed"
        self.channels = {}

    def get_available_channels(self):
        """
        Get the channels that have data stored in the dictionary
        """
        return sorted(list(self.channels.keys()))

    def get_channel_waveform(self, waveform_type='singles', channel=None, trace_index=0):
        """
        Get a specific waveform of a channel
        """
        if channel not in self.channels:
            raise ValueError(f"Channel {channel} not found in waveforms")  

        if waveform_type not in self.channels[channel].keys():
            raise ValueError(f"Channel {channel} does not have {waveform_type} data")

        if trace_index not in self.channels[channel][waveform_type].keys():
            raise ValueError(f"Channel {channel} does not have trace {trace_index} in {waveform_type} data. There are {len(self.channels[channel][waveform_type].keys())} traces available.")

        return self.channels[channel][waveform_type][trace_index]

    def get_pedestal(self, data, baseline_start_pct=0, baseline_end_pct=0.2):
        """
        Calculate the pedestal (baseline) of a waveform within a specified range of the data
        """
        start_idx = int(len(data) * baseline_start_pct)
        end_idx = int(len(data) * baseline_end_pct)
        return np.mean(data[start_idx:end_idx+1])

    def getPeakIndex(self, data, baseline_start_pct=0, baseline_end_pct=0.2, threshold=0.1):
        """
        Get the index of the peak value in a waveform above a specified threshold
        """
        peak_value = 0
        peak_index = 0
        threshold_index = 0
        crossed_threshold = False
        counter = 0
        pedestal = self.get_pedestal(data, baseline_start_pct, baseline_end_pct)
        for i in range(len(data)):
            if not crossed_threshold and data[i]-pedestal>threshold:
                crossed_threshold = True
                threshold_index = i
            if crossed_threshold and data[i]>peak_value:
                peak_value = data[i]
                peak_index = i
            elif crossed_threshold and data[i]<=peak_value:
                counter += 1
            if crossed_threshold and counter>1:
                break
        return peak_index, threshold_index
    
    def calculate_rise_time(self, channel, waveform_type='singles', trace_index=0, baseline_start_pct=0, baseline_end_pct=0.2, threshold=0.01, low_pct=0.1, high_pct=0.9):
        """
        Calculate rise time for a given channel.
        """
        df = self.get_channel_waveform(waveform_type, channel, trace_index)
        time = df["time"].values * 1e9 # Convert to ns
        signal = df["output"].values * 1e3 # Convert to mV

        pedestal = self.get_pedestal(signal, baseline_start_pct, baseline_end_pct)
        peak_index,threshold_index = self.getPeakIndex(signal, baseline_start_pct, baseline_end_pct, threshold)
        amplitude = signal[peak_index] - pedestal
        low_threshold = amplitude * low_pct - pedestal 
        high_threshold = amplitude * high_pct - pedestal

        interp_func = interp1d(time, signal, kind='linear') 
        interp_time = np.linspace(time[0], time[-1], num=10000)
        interp_signal = interp_func(interp_time)
        signal_above_low = interp_signal >= low_threshold
        signal_above_high = interp_signal >= high_threshold
        if np.any(signal_above_low):
            t_low_index = np.argmax(signal_above_low)
            t_low = interp_time[t_low_index]
        else:
            t_low = np.nan

        if np.any(signal_above_high):
            t_high_index = np.argmax(signal_above_high)
            t_high = interp_time[t_high_index]
        else:
            t_high = np.nan

        rise_time = t_high - t_low

        if 'analysis' not in self.channels[channel]:
            self.channels[channel]['analysis'] = {}

        self.channels[channel]['analysis']['rise_time'] = {
            "value": rise_time,
            "source": waveform_type,
            "trace_index": trace_index,
            "t_low": t_low,
            "t_high": t_high,
            "low_threshold": low_threshold,
            "high_threshold": high_threshold,
            "amplitude": amplitude,
            "pedestal": pedestal
        }
        return rise_time, t_low, t_high
    
    def calculate_all_rise_times(self, waveform_type='singles', trace_index=0, baseline_start=0, baseline_end=0.2, threshold=0.005, low_pct=0.1, high_pct=0.9):
        """
        Calculate rise times for all loaded channels.
        """
        results = {}
        for channel in self.channels:
            try:
                rt, t_low, t_high = self.calculate_rise_time(channel, baseline_start, baseline_end, threshold, low_pct, high_pct)
                results[channel] = rt
                print(f"Channel {channel}: Rise time = {rt:.2f} ns")
            except Exception as e:
                print(f"Error processing channel {channel}: {e}")
                results[channel] = np.nan
        return results
    
    def calculate_delays(self, waveform_type='averages', trace_index=0, reference_channel=None):
        """
        Calculate delays relative to fasted channel.
        """
        # Ensure rise times are calculated
        if not self.rise_times:
            self.calculate_all_rise_times()
        # Find reference channel (channel with minimum rise time)
        if reference_channel is None:
            min_rt = float('inf')
            for ch, data in self.rise_times.items():
                if data["rise_time"] < min_rt:
                    min_rt = data["rise_time"]
                    reference_channel = ch
        if reference_channel not in self.rise_times:
            raise ValueError(f"Reference channel {reference_channel} not found or has no rise time data")
        # Get reference t_low time
        ref_t_low = self.rise_times[reference_channel]["t_low"]
        # Calculate delays
        delays = {}
        for ch in self.channels:
            if ch in self.rise_times and not np.isnan(self.rise_times[ch]["t_low"]):
                # Calculate delay in picoseconds
                delay_ps = (self.rise_times[ch]["t_low"] - ref_t_low) * 1e12
                delays[ch] = delay_ps
            else:
                delays[ch] = np.nan
        self.delays = [delays[ch] for ch in self.channels]
        return delays
    
    def plot_waveform(self, channel, show_thresholds=True):
        """
        Plot a waveform for a specific channel.
        """
        if channel not in self.waveforms:
            raise ValueError(f"Channel {channel} not found in waveforms")
            
        df = self.waveforms[channel]
        
        fig, ax = plt.subplots(figsize=(10, 6))
        
        time_ns = df["time"].values * 1e9  # Convert to nanoseconds
        signal = df["amplitude"].values * 1e-3 # Convert to mV
        
        if channel in self.rise_times:
            pedestal = self.rise_times[channel]["pedestal"]
            signal_adj = signal - pedestal
        else:
            pedestal = self._get_pedestal(signal)
            signal_adj = signal - pedestal
            
        ax.plot(time_ns, signal_adj, label=f"{self.name} Ch{channel}")
        
        # Show threshold markers if requested and available
        if show_thresholds and channel in self.rise_times:
            rt_data = self.rise_times[channel]
            t_low_ns = rt_data["t_low"] * 1e9
            t_high_ns = rt_data["t_high"] * 1e9
            low_threshold = rt_data["low_threshold"]
            high_threshold = rt_data["high_threshold"]
            
            # Add vertical lines at threshold crossings
            ax.axvline(x=t_low_ns, color='green', linestyle='--', 
                      label=f'tLow = {t_low_ns:.2f} ns')
            ax.axvline(x=t_high_ns, color='red', linestyle='--', 
                      label=f'tHigh = {t_high_ns:.2f} ns')
            
            # Add horizontal lines at threshold levels
            ax.axhline(y=low_threshold, color='green', linestyle=':')
            ax.axhline(y=high_threshold, color='red', linestyle=':')
            
            # Add rise time text
            rise_time_ns = rt_data["rise_time"] * 1e9
            ax.text(0.7, 0.9, f"Rise Time: {rise_time_ns:.2f} ns", 
                   transform=ax.transAxes, fontsize=12, 
                   bbox=dict(facecolor='white', alpha=0.7))
            
        ax.set_xlabel("Time (ns)")
        ax.set_ylabel("Amplitude")
        ax.set_title(f"{self.name} Channel {channel} Waveform")
        ax.grid(True)
        ax.legend()
        
        plt.tight_layout()
        return fig
    
    def plot_delays(self, highlight_extremes=True):
        """
        Plot channel delays as a bar chart.
        
        Args:
            highlight_extremes (bool): Whether to highlight min/max delays
            
        Returns:
            matplotlib.figure.Figure: The figure object
        """
        if not self.delays:
            self.calculate_delays()
            
        # Create label list
        labels = [f"CH{ch}" for ch in self.channels]
        
        fig, ax = plt.subplots(figsize=(12, 6))
        
        # Create bar plot
        bars = ax.bar(labels, self.delays)
        
        if highlight_extremes:
            # Find min (excluding 0 which is reference)
            non_zero_delays = [d for d in self.delays if d != 0]
            if non_zero_delays:
                min_idx = self.delays.index(min(non_zero_delays))
                max_idx = self.delays.index(max(self.delays))
                
                # Find second lowest if needed
                if min_idx == 0:  # If the minimum is the reference channel (0 delay)
                    # Make a copy and replace the lowest with infinity
                    temp_values = self.delays.copy()
                    temp_values[min_idx] = float('inf')
                    min_idx = temp_values.index(min(temp_values))
                
                # Highlight min/max bars
                bars[min_idx].set_color('lightgreen')
                bars[max_idx].set_color('tomato')
                
                # Add value labels
                ax.text(min_idx, self.delays[min_idx]+20, f"{self.delays[min_idx]:.0f}", 
                       ha='center', rotation=90)
                ax.text(max_idx, self.delays[max_idx]+20, f"{self.delays[max_idx]:.0f}", 
                       ha='center', rotation=90)
        
        # Calculate standard deviation
        std_dev = np.std([d for d in self.delays if not np.isnan(d)])
        
        ax.set_xlabel("Channel")
        ax.set_ylabel("Delay Relative to Shortest Channel [ps]")
        ax.set_title(f"{self.name} Unity Path Relative Channel Delays")
        ax.set_ylim(0, max(self.delays) * 1.2)  # Add 20% margin
        ax.set_xticklabels(labels, rotation=90)
        ax.legend([f'$\\sigma= {std_dev:.2f}$'])
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        
        plt.tight_layout()
        return fig
        
    def load_data(self, file_pattern=None):
        """
        Load data files. To be implemented by subclasses.
        
        Args:
            file_pattern (str): Pattern to match files
            
        Returns:
            int: Number of channels loaded
        """
        raise NotImplementedError("Subclasses must implement load_data()")



In [75]:

class CASB1Processor(WaveformProcessor):
    """Processor for CASB1 board data."""
    
    def __init__(self):
        super().__init__(name="CASB1")
    
    def load_data(self, path="../data/casb1/singles/C1--Trace--*.txt"):
        """
        Load CASB1 data from text files.
        """
        files = glob.glob(path)
        if not files:
            raise ValueError(f"No files found matching pattern: {path}")
        self.waveforms = {}
        self.channels = []
        for file in files:
            try:
                filename = os.path.basename(file)
                match = re.search(r'Trace--(\d+)', filename)
                channel = int(match.group(1))
                df = pd.read_csv(file, skiprows=6, names=["time", "output"])
                for col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')
                self.waveforms[channel] = df
                if channel not in self.channels:
                    self.channels.append(channel)
                    
            except Exception as e:
                print(f"Error processing file {file}: {e}")
        
        # Sort channels numerically
        self.channels.sort()
        
        print(f"Loaded {len(self.channels)} channels for {self.name}")
        return len(self.channels)



In [76]:

class CASB2Processor(WaveformProcessor):
    """Processor for CASB2 board data."""
    
    def __init__(self):
        super().__init__(name="CASB2")
    
    def load_data(self, path="../data/casb2/singles/ch*/tek*ALL.csv"):
        """
        Load CASB2 data from CSV files.
        """
        files = glob.glob(path)
        if not files:
            raise ValueError(f"No files found matching pattern: {path}")
        self.waveforms = {}
        self.channels = []
        for file in files:
            try:
                filename = os.path.basename(file)
                match = re.search(r'ch(\d+)', filename)
                channel = int(match.group(1))
                df = pd.read_csv(file, skiprows=21, names=["time", "output", "input", "CH3"])
                for col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')
                df = df[["time", "output", "input"]]
                self.waveforms[channel] = df
                if channel not in self.channels:
                    self.channels.append(channel)
            except Exception as e:
                print(f"Error processing file {file}: {e}")
        
        # Sort channels numerically
        self.channels.sort()
        
        print(f"Loaded {len(self.channels)} channels for {self.name}")
        return len(self.channels)



In [77]:

class MTCAProcessor(WaveformProcessor):
    """Processor for MTCA board data."""
    
    def __init__(self):
        super().__init__(name="MTCA")
    
    def load_data(self, path="../data/mtca1/singles/C4--Trace--*.txt"):
        """
        Load MTCA data from CSV files.
        """
        files = glob.glob(path)
        
        if not files:
            raise ValueError(f"No files found matching pattern: {path}")
            
        # Clear existing data
        self.waveforms = {}
        self.channels = []
        
        for file in files:
            try:
                filename = os.path.basename(file)
                match = re.search(r'Trace--(\d+)ALL', filename)
                channel = int(match.group(1))
                df = pd.read_csv(file, skiprows=6, names=["time", "output"])
                for col in df.columns:
                    df[col] = pd.to_numeric(df[col], errors='coerce')
                self.waveforms[channel] = df
                if channel not in self.channels:
                    self.channels.append(channel)
            except Exception as e:
                print(f"Error processing file {file}: {e}")
        # Sort channels numerically
        self.channels.sort()
        
        print(f"Loaded {len(self.channels)} channels for {self.name}")
        return len(self.channels)


In [78]:


def plot_combined_delays(boards, highlight_extremes=True):
    """
    Create a combined bar plot showing delays from multiple boards.
    
    Args:
        boards (list): List of WaveformProcessor instances
        highlight_extremes (bool): Whether to highlight min/max delays
        
    Returns:
        matplotlib.figure.Figure: The figure object
    """
    # Ensure all boards have calculated delays
    for board in boards:
        if not board.delays:
            board.calculate_delays()
    
    # Get common set of channels
    channels = []
    for board in boards:
        channels.extend(board.channels)
    channels = sorted(list(set(channels)))
    
    # Create labels
    labels = [f"CH{ch}" for ch in channels]
    
    # Set up the plot
    fig, ax = plt.subplots(figsize=(15, 8))
    
    # Calculate positions for bars
    x = np.arange(len(labels))
    width = 0.8 / len(boards)  # Width of the bars, adjusted for number of boards
    
    # Plot bars for each board
    legend_items = []
    
    for i, board in enumerate(boards):
        # Create dictionary mapping channel to delay
        delay_dict = {ch: delay for ch, delay in zip(board.channels, board.delays)}
        
        # Fill in delays for all channels (use NaN for missing channels)
        delays = [delay_dict.get(ch, np.nan) for ch in channels]
        
        # Calculate position adjustment
        pos_adjustment = width * (i - (len(boards) - 1) / 2)
        
        # Create bars
        bars = ax.bar(x + pos_adjustment, delays, width, label=board.name)
        
        if highlight_extremes:
            # Find non-NaN values
            valid_delays = [(idx, d) for idx, d in enumerate(delays) if not np.isnan(d)]
            if valid_delays:
                valid_indices, valid_values = zip(*valid_delays)
                
                # Find min (excluding 0 which might be reference)
                non_zero_values = [d for d in valid_values if d != 0]
                if non_zero_values:
                    min_val = min(non_zero_values)
                    min_idx = valid_indices[valid_values.index(min_val)]
                    max_val = max(valid_values)
                    max_idx = valid_indices[valid_values.index(max_val)]
                    
                    # Highlight min/max bars
                    bars[min_idx].set_color('lightgreen')
                    bars[max_idx].set_color('tomato')
                    
                    # Add value labels
                    ax.text(x[min_idx] + pos_adjustment, delays[min_idx]+20, 
                           f"{delays[min_idx]:.0f}", ha='center', rotation=90, fontsize=8)
                    ax.text(x[max_idx] + pos_adjustment, delays[max_idx]+20, 
                           f"{delays[max_idx]:.0f}", ha='center', rotation=90, fontsize=8)
        
        # Calculate standard deviation for legend
        valid_delays = [d for d in delays if not np.isnan(d)]
        std_dev = np.std(valid_delays) if valid_delays else 0
        legend_items.append(f'{board.name} ($\\sigma= {std_dev:.2f}$)')
    
    # Customize the plot
    ax.set_xlabel("Channel")
    ax.set_ylabel("Delay Relative to Shortest Channel [ps]")
    ax.set_title("Unity Path Relative Channel Delays Comparison")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=90)
    
    # Find appropriate y-limit
    max_delay = max([max([d for d in board.delays if not np.isnan(d)]) for board in boards])
    ax.set_ylim(0, max_delay * 1.2)  # Add 20% margin
    
    # Add legend with standard deviations
    ax.legend(legend_items)
    
    # Add grid
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    
    plt.tight_layout()
    return fig



In [79]:

# Initialize processors
casb1 = CASB1Processor()
casb2 = CASB2Processor()
mtca1 = MTCAProcessor()

# Load data
casb1.load_data();
casb2.load_data();
mtca1.load_data();

# Calculate rise times
casb1.calculate_all_rise_times(low_pct=0.1, high_pct=0.9)
casb2.calculate_all_rise_times(low_pct=0.1, high_pct=0.9)
# mtca.calculate_all_rise_times()  # Uncomment if you have MTCA data

# Calculate delays
casb1.calculate_delays()
casb2.calculate_delays()
# mtca.calculate_delays()  # Uncomment if you have MTCA data

# Plot a sample waveform
casb1.plot_waveform(1, show_thresholds=True)

# Plot delays for individual boards
casb1.plot_delays()
casb2.plot_delays()

# Plot combined delays
plot_combined_delays([casb1, casb2])  # Add mtca to the list if available

Loaded 30 channels for CASB1
Error processing file ../data/casb2/singles/ch1/tek0001ALL.csv: 'NoneType' object has no attribute 'group'
Error processing file ../data/casb2/singles/ch1/tek0007ALL.csv: 'NoneType' object has no attribute 'group'
Error processing file ../data/casb2/singles/ch1/tek0018ALL.csv: 'NoneType' object has no attribute 'group'
Error processing file ../data/casb2/singles/ch1/tek0012ALL.csv: 'NoneType' object has no attribute 'group'
Error processing file ../data/casb2/singles/ch1/tek0000ALL.csv: 'NoneType' object has no attribute 'group'
Error processing file ../data/casb2/singles/ch1/tek0002ALL.csv: 'NoneType' object has no attribute 'group'
Error processing file ../data/casb2/singles/ch1/tek0016ALL.csv: 'NoneType' object has no attribute 'group'
Error processing file ../data/casb2/singles/ch1/tek0004ALL.csv: 'NoneType' object has no attribute 'group'
Error processing file ../data/casb2/singles/ch1/tek0010ALL.csv: 'NoneType' object has no attribute 'group'
Error pr

AttributeError: 'CASB1Processor' object has no attribute 'rise_times'