In [None]:
'''
This code runs using my custom overlay and aims to output a simple square wave at ~10 MHz.
Since this is right at the cutoff frequency of the baluns, the output ends up looking like the blue signal
shown in "squareWave.png". To fix this we apply a basic "digital pre-distortion" to get the desired output.
The correction applied to channel 0 is roughly half that of channel 1, since channel 0 only passes through the output balun, while
channel 1 passes through both the output (DAC) balun and the input (ADC) balun. In this testing I had channel 0 connected to
an oscilloscope while channel 1 was looped back to an ADC.
'''

import numpy as np
import time
from scipy.signal import square, correlate

# Import the custom overlay driver. Both spellings of the module name are tried
# ("mulitple_output" first, then "multiple_output"); presumably the driver file
# has existed under both names — TODO confirm and drop the unused spelling.
try:
    from mulitple_output import SeparateDacOverlay
except ImportError:
    try:
        from multiple_output import SeparateDacOverlay
    except ImportError as exc:
        # Raise rather than call exit(): in an IPython/Jupyter kernel exit()
        # ends the kernel session, whereas an exception just stops this cell
        # (and "Run All") and shows the underlying cause.
        raise ImportError(
            "CRITICAL: Could not import SeparateDacOverlay. Check file location."
        ) from exc

# Bitstream for the overlay; the path is relative to the notebook's working directory.
BITFILE = './Working_Overlays/multi_dac_4gsps.bit'
# Initialize Overlay (every later cell talks to the hardware through `ol`)
ol = SeparateDacOverlay(BITFILE)

In [2]:
# Parameters
DAC_SR = 4.0e9          # DAC sample rate: 4 GS/s
TARGET_FREQ = 10.15e6   # Requested square-wave frequency (~10 MHz); may be adjusted to fit the DAC buffer
DAC_AMP = 2**15 - 1     # Max amplitude (int16)
ILC_ITERATIONS = 15     # Number of pre-distortion (iterative learning control) passes
LEARNING_RATE = 0.25    # ILC step size: fraction of the measured error added to the DAC waveform
LEAK_FACTOR = 0.90      # Fraction of the previous DAC waveform retained on each ILC iteration

In [None]:
# Get hardware buffer size
N = ol.dac0_bram.shape[0]

# Adjust Frequency to fit exactly integer cycles in N samples.
# The DAC replays the buffer in a loop, so a fractional cycle would cause a
# phase jump ("click") at every wrap. At least one cycle is enforced: for a
# short buffer the rounding could otherwise yield 0 cycles, i.e. a 0 Hz
# (constant) "square wave".
num_cycles = np.maximum(np.round((TARGET_FREQ * N) / DAC_SR), 1.0)
actual_freq = (num_cycles * DAC_SR) / N
print(f"Frequency adjusted to {actual_freq/1e6:.4f} MHz to fit buffer perfectly.")

# Sample times for one full DAC buffer and the ideal +/-DAC_AMP square wave.
t = np.arange(N) / DAC_SR
target_wave = DAC_AMP * square(2 * np.pi * actual_freq * t)

# Initialize the DAC output guess (Start with the ideal square)
dac_wave_float = target_wave.copy()

In [None]:

def get_aligned_error(target, measured):
    """
    Align a measured waveform to a target waveform and return the shape error.

    The measurement is time-aligned to the target at the peak of their
    cross-correlation (computed on the mean-removed signals). Its AC part is
    then rescaled so its standard deviation matches the target's, and its DC
    level is replaced by the target's mean. As a result, neither overall
    attenuation nor a constant ADC offset appears in the error. This matters
    because the DAC->ADC path goes through baluns, which presumably cannot pass
    DC, so a DC error could never be corrected and would only accumulate in the
    iterative update.

    Parameters
    ----------
    target : array_like
        Desired waveform (1-D).
    measured : array_like
        Captured waveform, same shape as ``target``. Integer ADC samples are fine.

    Returns
    -------
    error : numpy.ndarray
        ``target - m_scaled``.
    m_scaled : numpy.ndarray
        The aligned, rescaled measurement.

    Raises
    ------
    ValueError
        If either input is empty or their shapes differ.
    """
    target = np.asarray(target, dtype=float)
    measured = np.asarray(measured, dtype=float)
    if target.size == 0 or measured.size == 0:
        raise ValueError("target and measured must be non-empty")
    if target.shape != measured.shape:
        raise ValueError(
            f"target and measured must have the same shape, got {target.shape} and {measured.shape}"
        )

    # 1. Remove DC offsets; alignment and gain matching both use the AC parts.
    t_mean = np.mean(target)
    t_ac = target - t_mean
    m_ac = measured - np.mean(measured)

    # 2. Find time lag. With mode='full', index len(t_ac) - 1 is zero lag and
    #    correlation[k] = sum_n m_ac[n + lag] * t_ac[n] with lag = k - (len(t_ac) - 1).
    correlation = correlate(m_ac, t_ac, mode='full')
    lag = np.argmax(correlation) - (len(t_ac) - 1)

    # 3. Shift the measurement back by the lag (circular: the buffers repeat).
    m_aligned = np.roll(m_ac, -lag)

    # 4. Amplitude Scaling (Gain Correction): fix the SHAPE (droop), not the
    #    attenuation, by matching the target's standard deviation. Because the
    #    measurement's DC was removed, a constant capture yields a constant
    #    m_scaled instead of being amplified by the 1e-12 guard.
    gain_scale = np.std(t_ac) / (np.std(m_aligned) + 1e-12)
    m_scaled = m_aligned * gain_scale + t_mean

    # 5. Calculate Error
    return target - m_scaled, m_scaled

Frequency adjusted to 10.1318 MHz to fit buffer perfectly.


In [None]:

import matplotlib.pyplot as plt
import numpy as np
import time

# Ch0 only passes through the DAC balun, so it receives ~half of Ch1's correction.
channel_0_corectionFactor = 0.55

print(f"Starting Pre-distortion for {ILC_ITERATIONS} iterations...")

# Target DAC Amplitude: Full 16-bit signed range (2^15 - 1)
DAC_MAX = 2**15 - 1
# Samples whose magnitude exceeds this percentile of |wave| are clipped before
# normalizing. NOTE: 95 clips the largest ~5% of samples (not 0.5%).
PERCENTILE_THRESHOLD = 95.

# ILC update parameters. They are loop-invariant, so they are defined once here.
LEAK_FACTOR = 0.90    # fraction of the previous DAC waveform kept each iteration
LEARNING_RATE = 0.25  # fraction of the measured error added each iteration

initial_measured = None  # Variable to store the "before" snapshot


def _clip_and_normalize(wave, percentile, full_scale):
    """Clip spurious peaks, rescale to full DAC range, and quantize to int16.

    Returns ``(float_wave, int16_wave)``. ``float_wave`` is the clipped and
    rescaled float waveform (the new ILC state); ``int16_wave`` is what gets
    written to the DAC BRAM. Returns new arrays and does not modify ``wave``.
    """
    p_limit = np.percentile(np.abs(wave), percentile)
    wave = np.clip(wave, -p_limit, p_limit)
    current_max = np.max(np.abs(wave))
    if current_max > 1e-8:  # skip rescaling an (almost) all-zero waveform
        wave = wave * (full_scale / current_max)
    return wave, np.clip(wave, -32767, 32767).astype(np.int16)


# --- SPLIT CHANNELS ---
# Initialize both channels from the starting wave.
# Ch1 = Master (Connected to ADC, Full Correction)
# Ch0 = Slave (scaled-down Correction)
dac_wave_float_ch0 = dac_wave_float.copy()
dac_wave_float_ch1 = dac_wave_float.copy()

for i in range(ILC_ITERATIONS):
    # --- A. Prepare and Load DACs (Independent Normalization per channel) ---
    dac_wave_float_ch1, wave_i16_ch1 = _clip_and_normalize(
        dac_wave_float_ch1, PERCENTILE_THRESHOLD, DAC_MAX)
    dac_wave_float_ch0, wave_i16_ch0 = _clip_and_normalize(
        dac_wave_float_ch0, PERCENTILE_THRESHOLD, DAC_MAX)

    ol.dac0_bram[:] = wave_i16_ch0
    ol.dac1_bram[:] = wave_i16_ch1

    # --- B. Fire and Capture ---
    ol.start_dacs()
    time.sleep(0.05)  # give the output time to settle before capturing

    if hasattr(ol, 'trigger_capture'):
        ol.trigger_capture()

    # We are measuring the system response via Ch1 (Master)
    raw_adc = np.array(ol.adc_capture_chC)
    if raw_adc.size == 0:
        raise RuntimeError("ADC capture (adc_capture_chC) returned no samples.")

    # --- C. Handle Buffer Size Mismatches (compare only the overlapping part) ---
    calc_len = min(len(target_wave), len(raw_adc))
    t_slice = target_wave[:calc_len]
    m_slice = raw_adc[:calc_len]

    # --- D. Calculate Error (derived from Ch1's physical output) ---
    error_signal, m_aligned = get_aligned_error(t_slice, m_slice)

    # Capture the very first iteration for comparison later
    if i == 0:
        initial_measured = m_aligned.copy()

    mse = np.mean(error_signal**2)
    max_err = np.max(np.abs(error_signal))
    print(f"Iteration {i+1}: MSE = {mse:.2f} | Max Error = {max_err:.2f}")

    # --- E. Leaky Update (Split Ratios) ---
    # 1. Apply Leak (Both channels leak slightly to prevent drift)
    dac_wave_float_ch0[:calc_len] *= LEAK_FACTOR
    dac_wave_float_ch1[:calc_len] *= LEAK_FACTOR

    # 2. Calculate Corrections (Ch0 reuses Ch1's error, scaled down)
    correction_ch1 = LEARNING_RATE * error_signal
    correction_ch0 = correction_ch1 * channel_0_corectionFactor

    # 3. Apply Corrections
    dac_wave_float_ch0[:calc_len] += correction_ch0
    dac_wave_float_ch1[:calc_len] += correction_ch1

    # --- Cyclic Fill for buffer mismatches ---
    # Repeat the updated segment [:calc_len] over the rest of the buffer.
    # np.resize repeats its input cyclically, so this also holds when the tail
    # is longer than the updated segment (slicing the full buffer would copy
    # stale, non-updated samples there).
    # NOTE(review): this is a correct periodic extension only if calc_len is a
    # multiple of the square-wave period.
    if N > calc_len:
        remaining = N - calc_len
        dac_wave_float_ch0[calc_len:] = np.resize(dac_wave_float_ch0[:calc_len], remaining)
        dac_wave_float_ch1[calc_len:] = np.resize(dac_wave_float_ch1[:calc_len], remaining)

In [None]:
# --- Helper to Normalize a Signal to +/- 1.0 (for plotting only) ---
def norm_one(sig):
    """Return `sig` scaled so its peak magnitude is 1.0; an all-zero signal is returned as-is."""
    peak = np.max(np.abs(sig))
    if peak == 0:
        return sig
    return sig / peak


# Final pass: run the last ILC iterate of each channel through the same
# clip -> rescale -> int16 steps used during the loop, then load and start.

# Channel 0 (scope output)
p_limit_ch0 = np.percentile(np.abs(dac_wave_float_ch0), PERCENTILE_THRESHOLD)
dac_wave_float_ch0 = dac_wave_float_ch0.clip(-p_limit_ch0, p_limit_ch0)
current_max_ch0 = np.abs(dac_wave_float_ch0).max()
if current_max_ch0 > 1e-8:
    dac_wave_float_ch0 = dac_wave_float_ch0 * (DAC_MAX / current_max_ch0)
final_wave_i16_ch0 = dac_wave_float_ch0.clip(-32767, 32767).astype(np.int16)

# Channel 1 (looped back to the ADC)
p_limit_ch1 = np.percentile(np.abs(dac_wave_float_ch1), PERCENTILE_THRESHOLD)
dac_wave_float_ch1 = dac_wave_float_ch1.clip(-p_limit_ch1, p_limit_ch1)
current_max_ch1 = np.abs(dac_wave_float_ch1).max()
if current_max_ch1 > 1e-8:
    dac_wave_float_ch1 = dac_wave_float_ch1 * (DAC_MAX / current_max_ch1)
final_wave_i16_ch1 = dac_wave_float_ch1.clip(-32767, 32767).astype(np.int16)

# Push both pre-distorted buffers to the hardware and start playback.
ol.dac0_bram[:] = final_wave_i16_ch0
ol.dac1_bram[:] = final_wave_i16_ch1
ol.start_dacs()

Starting Pre-distortion for 15 iterations...
Iteration 1: MSE = 263819175.10 | Max Error = 43506.65
Iteration 2: MSE = 158570947.24 | Max Error = 47234.30
Iteration 3: MSE = 86029177.76 | Max Error = 50323.83
Iteration 4: MSE = 47031720.04 | Max Error = 53442.06
Iteration 5: MSE = 25332150.55 | Max Error = 55809.53
Iteration 6: MSE = 20690047.84 | Max Error = 116424.11


In [25]:
# Silence both DACs by loading all-zero buffers and restarting playback.
# The buffer length comes straight from the hardware BRAM, so this cell works
# even if the pre-distortion cells were never run or were interrupted.
# Assumes dac1_bram has the same length as dac0_bram, as the ILC cells already do.
silence = np.zeros(ol.dac0_bram.shape[0], dtype=np.int16)
ol.dac0_bram[:] = silence
ol.dac1_bram[:] = silence
ol.start_dacs()