In [None]:
import numpy as np
from sklearn.neighbors import KNeighborsRegressor
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, DotProduct, WhiteKernel, ConstantKernel
from sklearn.metrics import r2_score, mean_squared_error
import matplotlib.pyplot as plt
plt.ion()  # Enable interactive mode for live plotting
import time
import os
import shutil
import pyvisa
import csv

class Info:
    # Shared run configuration plus instrument handles. The rest of the script
    # reads and assigns these as class attributes; the class is never instantiated.

    # --- surrogate / run selection ---
    icase = 6         # not read anywhere in the visible code
    icmd = 1          # driver1: 1 -> Multi_knn ensemble, any other value -> GP
    models12 = False  # True adds six Minkowski p=1.5 KNN models to the ensemble

    # --- sweep definition ---
    a0 = -1.5         # lower bound of the voltage sweep
    a1 = -0.5         # upper bound of the voltage sweep
    ndim = 1          # number of input dimensions
    ntrain0 = 36      # evenly spaced points in the initial coarse scan
    nn1 = 100         # size of the candidate / evaluation grid
    tolsig = 1e-3     # predicted std at or above this triggers a new measurement

    # --- runtime handles, filled in by the __main__ block ---
    rm = None
    yokogawa = None
    multimeter = None
    fname_accum = None

class GP():
    """Gaussian-process surrogate with a fixed ConstantKernel * RBF kernel.

    Training data is (re)read from the accumulated-measurement file
    ``Info.fname_accum`` every time :meth:`train0` is called.
    """

    def __init__(self, ndim):
        self.ndim = ndim
        # Both hyperparameters use "fixed" bounds, so fit() does not tune the
        # amplitude or the length scale.
        self.kernel = (ConstantKernel(1.0, constant_value_bounds="fixed")
                       * RBF(1.0, length_scale_bounds="fixed"))
        self.gp = GaussianProcessRegressor(
            kernel=self.kernel, n_restarts_optimizer=0)

    @staticmethod
    def _read_accum(fname, ndim):
        """Parse an accumulated-data file into ``(X, y)`` arrays in one pass.

        Each non-blank line holds whitespace-separated numbers: the first goes
        into ``X[:, 0]`` and the last is the target ``y``. Blank lines are
        skipped instead of raising IndexError. Columns of ``X`` beyond the
        first stay zero (1-D file layout).
        """
        xs, ys = [], []
        with open(fname, 'r') as afile:
            for line in afile:
                parts = line.split()
                if not parts:
                    continue
                xs.append(float(parts[0]))
                ys.append(float(parts[-1]))
        X = np.zeros((len(xs), ndim))
        X[:, 0] = xs
        return X, np.asarray(ys, dtype=float)

    def train0(self):
        """Fit the GP on every point in ``Info.fname_accum``.

        Does nothing if the file does not exist.
        """
        if not os.path.isfile(Info.fname_accum):
            return
        X, y = self._read_accum(Info.fname_accum, self.ndim)
        self.gp.fit(X, y)

    def predict(self, X_test):
        """Return ``(mean, std)`` of the GP posterior at ``X_test``.

        For a single query row the two values are returned as scalars rather
        than length-1 arrays, so callers can compare ``std`` directly.
        """
        y_pred, y_std = self.gp.predict(X_test, return_std=True)
        if X_test.shape[0] == 1:
            y_pred = y_pred[0]
            y_std = y_std[0]
        return y_pred, y_std

    def tpredict(self, X_test):
        """Retrain from the data file, then predict at ``X_test``."""
        self.train0()
        return self.predict(X_test)

class Multi_knn():
    """Ensemble of distance-weighted KNN regressors used as a surrogate.

    Six models with ``n_neighbors`` 4..9 are always built. When ``models12``
    is true, six more with Minkowski ``p=1.5`` are added. Training data is
    (re)read from ``Info.fname_accum`` by :meth:`train0`.
    """

    def __init__(self, ndim, models12):
        self.ndim = ndim
        self.models12 = models12
        self.knns = [KNeighborsRegressor(n_neighbors=k, weights='distance')
                     for k in range(4, 10)]
        if models12:
            self.knns.extend([KNeighborsRegressor(n_neighbors=k, p=1.5, weights='distance')
                              for k in range(4, 10)])

    @staticmethod
    def _read_accum(fname, ndim):
        """Parse an accumulated-data file into ``(X, y)`` arrays in one pass.

        Each non-blank line holds whitespace-separated numbers: the first goes
        into ``X[:, 0]`` and the last is the target ``y``. Blank lines are
        skipped instead of raising IndexError. Columns of ``X`` beyond the
        first stay zero (1-D file layout).
        """
        xs, ys = [], []
        with open(fname, 'r') as afile:
            for line in afile:
                parts = line.split()
                if not parts:
                    continue
                xs.append(float(parts[0]))
                ys.append(float(parts[-1]))
        X = np.zeros((len(xs), ndim))
        X[:, 0] = xs
        return X, np.asarray(ys, dtype=float)

    def train0(self):
        """Fit every ensemble member on the points in ``Info.fname_accum``.

        Does nothing if the file does not exist.
        NOTE(review): the largest model uses 9 neighbours; scikit-learn
        requires n_neighbors <= number of samples when predicting, so keep at
        least 9 stored points.
        """
        if not os.path.isfile(Info.fname_accum):
            return
        X, y = self._read_accum(Info.fname_accum, self.ndim)
        for knn in self.knns:
            knn.fit(X, y)

    def predict(self, X_test):
        """Return ``(mean, std)`` across ensemble members for the first row of X_test.

        ``std`` is numpy's default population standard deviation (ddof=0) of
        the members' predictions, used by driver1 as an uncertainty proxy.
        """
        predictions = [knn.predict(X_test)[0] for knn in self.knns]
        return np.mean(predictions), np.std(predictions)

    def tpredict(self, X_test):
        """Retrain from the data file, then predict at ``X_test``."""
        self.train0()
        return self.predict(X_test)

def append_new_line(file_name, text_to_append):
    """Append ``text_to_append`` to ``file_name`` as its own line.

    The file is created if missing. A newline separator is written first only
    when the file is non-empty and does not already end with a line break.
    This avoids the blank lines that a trailing newline would otherwise
    produce. No trailing newline is added after the new text.
    """
    needs_separator = False
    if os.path.isfile(file_name) and os.path.getsize(file_name) > 0:
        # Inspect only the last byte instead of reading file content.
        with open(file_name, "rb") as fh:
            fh.seek(-1, os.SEEK_END)
            needs_separator = fh.read(1) not in (b"\n", b"\r")
    with open(file_name, "a") as file_object:
        if needs_separator:
            file_object.write("\n")
        file_object.write(text_to_append)

def afun(x):
    """Set the source to voltage ``x`` and return the multimeter reading.

    The Yokogawa is programmed with an ``S<value>E`` command (5 decimals).
    After a 0.3 s settle, the DMM status is cleared and its ``READ?`` reply
    is returned as a float. Any failure is printed and then re-raised.
    """
    try:
        setpoint_cmd = f'S{x:0.5f}E'
        Info.yokogawa.write(setpoint_cmd)
        time.sleep(0.3)  # settling time before reading

        meter = Info.multimeter
        meter.write("*CLS")
        time.sleep(0.1)
        return float(meter.query("READ?"))
    except Exception as err:
        print(f"Error in measurement at V={x}: {err}")
        raise

def list_visa_resources():
    """Print every VISA resource a fresh ResourceManager can see, and return them.

    The temporary ResourceManager is closed even if listing fails, so a VISA
    error here does not leak the session.

    Returns:
        The value returned by ``rm.list_resources()`` (resource-name strings).
    """
    rm = pyvisa.ResourceManager()
    try:
        print("\nAvailable VISA resources:")
        resources = rm.list_resources()
        if not resources:
            print("No VISA resources found!")
        else:
            for res in resources:
                print(f"  - {res}")
    finally:
        rm.close()
    return resources

def test_instrument_commands():
    """Send a fixed command list to each instrument and print the outcome.

    Yokogawa (write only): 'F1R5O1E', 'S0.0E', 'S10.0E', 'O0E'.
    DMM: '*RST', '*CLS', then the query 'READ?', whose reply is printed.
    A failing command is reported and the next one is still attempted. A
    0.2 s pause follows each command that succeeds.

    NOTE(review): 'S10.0E' requests 10 V from the source; make sure that is
    safe for whatever is connected before calling this.
    """
    def _send(device, command, is_query):
        # One command: print the reply (query) or "OK" (write), or the error.
        try:
            if is_query:
                reply = device.query(command)
                print(f"  {command} -> {reply}")
            else:
                device.write(command)
                print(f"  {command} -> OK")
            time.sleep(0.2)
        except Exception as exc:
            print(f"  {command} -> Failed: {exc}")

    print("\n=== Testing Instrument Communication ===")

    print("\nTesting Yokogawa commands:")
    for command in ('F1R5O1E', 'S0.0E', 'S10.0E', 'O0E'):  # Commands from working code
        _send(Info.yokogawa, command, is_query=False)

    print("\nTesting DMM commands:")
    for command in ('*RST', '*CLS', 'READ?'):
        _send(Info.multimeter, command, is_query='?' in command)

    print("\n=== End of Communication Test ===\n")

def _backup_accum_file():
    """Move an existing ``Info.fname_accum`` aside under a randomized name.

    This lets the new run start from an empty data file.
    """
    if os.path.isfile(Info.fname_accum):
        stem, ext = os.path.splitext(Info.fname_accum)
        gname = f"{stem}{np.random.randint(1000000)}{ext}"
        os.replace(Info.fname_accum, gname)


def _update_live_plots(all_data, live_fig, live_axes, jter):
    """Redraw the three live panels: overlay, mean ± std, latest uncertainty."""
    ax1, ax2, ax3 = live_axes
    for ax in (ax1, ax2, ax3):
        ax.clear()

    voltages = all_data['voltages']
    currents = all_data['currents']

    # Panel 1: every iteration overlaid; only the first five get labels.
    for idx, current in enumerate(currents):
        ax1.plot(voltages, current, alpha=0.5, label=f'Iter {idx}' if idx < 5 else "")
    ax1.set_xlabel('Voltage (V)')
    ax1.set_ylabel('Current (A)')
    ax1.set_title('All Iterations Overlaid')
    ax1.grid(True)
    if len(currents) <= 5:
        ax1.legend()

    # Panel 2: mean ± 1 std across iterations (needs at least two runs).
    if len(currents) > 1:
        mean_current = np.mean(currents, axis=0)
        std_current = np.std(currents, axis=0)
        ax2.plot(voltages, mean_current, 'b-', linewidth=2, label='Mean')
        ax2.fill_between(voltages,
                         mean_current - std_current,
                         mean_current + std_current,
                         alpha=0.3, label='±1 STD')
        ax2.set_xlabel('Voltage (V)')
        ax2.set_ylabel('Current (A)')
        ax2.set_title(f'Mean Current (n={len(currents)} iterations)')
        ax2.grid(True)
        ax2.legend()

    # Panel 3: surrogate uncertainty of the latest iteration.
    ax3.plot(voltages, all_data['stds'][-1], 'r-')
    ax3.set_xlabel('Voltage (V)')
    ax3.set_ylabel('Prediction Std. Dev. (A)')
    ax3.set_title(f'Prediction Uncertainty - Iteration {jter}')
    ax3.grid(True)

    # Act on the live figure explicitly rather than whatever figure is "current".
    live_fig.tight_layout()
    live_fig.canvas.draw_idle()
    plt.pause(0.1)  # Update display


def _save_iteration_plots(b, nu, st, jter):
    """Write mean_{jter}.pdf and std_{jter}.pdf, closing each figure afterwards."""
    fig = plt.figure(figsize=(10, 5))
    plt.plot(b, nu, label='Mean Current')
    plt.xlabel('$V_1$ (V)', fontsize=14)
    plt.ylabel('Current (A)', fontsize=14)
    plt.title(f'Mean Current vs. Voltage - Iteration {jter}', fontsize=16)
    plt.legend()
    plt.grid(True)
    plt.savefig(f'mean_{jter}.pdf')
    plt.close(fig)

    fig = plt.figure(figsize=(10, 5))
    plt.plot(b, st, label='Std. Dev. in Current', color='orange')
    plt.xlabel('$V_1$ (V)', fontsize=14)
    plt.ylabel('Std. Dev. (A)', fontsize=14)
    plt.title(f'Prediction Std. Dev. vs. Voltage - Iteration {jter}', fontsize=16)
    plt.legend()
    plt.grid(True)
    plt.savefig(f'std_{jter}.pdf')
    plt.close(fig)


def driver1(jter, all_data=None, live_fig=None, live_axes=None):
    """Run one adaptive I-V sweep and save its surrogate prediction.

    Steps:
      1. Reset the DMM (best effort).
      2. Move any old data file aside.
      3. Measure ``Info.ntrain0`` evenly spaced points in [a0, a1].
      4. Walk an ``Info.nn1``-point grid and measure every point where the
         surrogate std is >= ``Info.tolsig``, retraining after each one.
      5. Evaluate mean/std of the surrogate on the grid.

    Args:
        jter: iteration index, used in plot titles and PDF file names.
        all_data: optional dict with keys 'voltages', 'currents', 'stds'. When
            given, the grid is stored and this run's mean/std are appended.
        live_fig, live_axes: optional figure and 3-tuple of axes to redraw
            (only used together with ``all_data``).

    Side effects: writes ``Info.fname_accum``; writes ``<stem>.npy`` holding
    two consecutive ``np.save`` arrays (mean, then std); writes
    ``mean_{jter}.pdf`` and ``std_{jter}.pdf``.
    """
    try:
        Info.multimeter.write("*RST")
        time.sleep(1)  # Wait for reset
    except Exception:
        print("Warning: Could not reset DMM, continuing...")

    print(f'{Info.ndim} dimensions')
    print(f'{Info.a0}, {Info.a1} boundary parameters')
    print(f'{Info.nn1}, {Info.tolsig} n1, tolsig')
    print(f'{Info.ntrain0} initial points')

    start = time.time()
    a0, a1, nn1, ndim, ntrain0, tolsig, models12 = (
        Info.a0, Info.a1, Info.nn1, Info.ndim, Info.ntrain0, Info.tolsig, Info.models12)
    ntrain = ntrain0

    _backup_accum_file()

    # Coarse, evenly spaced initial scan.
    X_train = np.linspace(a0, a1, ntrain0).reshape(-1, 1)
    y_train = np.zeros(ntrain0)
    for i in range(ntrain0):
        y_train[i] = afun(X_train[i, 0])
        append_new_line(Info.fname_accum, f"{X_train[i, 0]} {y_train[i]}")
        print(f"Point {i+1}/{ntrain0}: V={X_train[i, 0]:.3f}V, I={y_train[i]:.6f}")
    print('coarse', ntrain0)

    use_knn = Info.icmd == 1
    multi = Multi_knn(ndim, models12) if use_knn else GP(ndim)
    print(f'Using {"Multi_knn" if use_knn else "GP"}')
    multi.train0()
    print('train0: the first training')

    # The same grid is used for adaptive refinement and for the final prediction.
    b = np.linspace(a0, a1, nn1)
    X_test = np.zeros((1, ndim))

    # Adaptive refinement: measure wherever the surrogate is too uncertain.
    for x_val in b:
        X_test[0, 0] = x_val
        _, y_std = multi.predict(X_test)
        if y_std >= tolsig:
            ntrain += 1
            tmp = afun(X_test[0, 0])
            append_new_line(Info.fname_accum, f"{X_test[0, 0]} {tmp}")
            multi.tpredict(X_test)  # Retrain with the new point
            print(f"Added adaptive point {ntrain-ntrain0}: V={X_test[0, 0]:.3f}V")

    print(f'{ntrain} total points, ratio: {float(ntrain)/float(nn1):.2f}')

    # Final surrogate mean/std on the grid.
    nu = np.zeros(nn1)
    st = np.zeros(nn1)
    for i, x_val in enumerate(b):
        X_test[0, 0] = x_val
        nu[i], st[i] = multi.predict(X_test)

    end = time.time()
    print(f"Total time: {end - start:.2f} seconds")

    stem = os.path.splitext(Info.fname_accum)[0]
    with open(f"{stem}.npy", 'wb') as afile:
        np.save(afile, nu)
        np.save(afile, st)

    # Store data for accumulation and refresh the live view.
    if all_data is not None:
        all_data['voltages'] = b
        all_data['currents'].append(nu)
        all_data['stds'].append(st)
        if live_fig is not None and live_axes is not None:
            _update_live_plots(all_data, live_fig, live_axes, jter)

    _save_iteration_plots(b, nu, st, jter)

if __name__ == '__main__':
    # Number of iterations to run
    num_iterations = 20  # Change this to control number of iterations

    def _source_off():
        """Best effort: drive the Yokogawa to 0 V and switch its output off.

        'S0.0E' uses the same set-voltage format as afun. 'O0E' is presumably
        "output off", mirroring the 'O1' (output on) field of the 'F1R5O1E'
        init command. Confirm against the instrument manual.
        """
        if Info.yokogawa is None:
            return
        try:
            Info.yokogawa.write('S0.0E')
            Info.yokogawa.write('O0E')
            print("Yokogawa set to 0 V, output off")
        except Exception as e:
            print(f"Warning: could not switch off Yokogawa output: {e}")

    try:
        # --- First, list available resources (printed for troubleshooting) ---
        list_visa_resources()

        # --- Instrument Initialization ---
        # GPIB addresses from your working code
        YOKOGAWA_ADDR = 'GPIB0::11::INSTR'
        DMM_ADDR = 'GPIB0::22::INSTR'

        # Alternative USB/Serial addresses if using USB-GPIB adapter
        # YOKOGAWA_ADDR = 'USB0::0x0957::0x0407::MY12345678::INSTR'
        # DMM_ADDR = 'ASRL3::INSTR'  # For serial connection

        print("\nAttempting to connect to:")
        print(f"  Yokogawa: {YOKOGAWA_ADDR}")
        print(f"  DMM: {DMM_ADDR}")

        # Set up the VISA resource manager
        Info.rm = pyvisa.ResourceManager()

        # Open connection to the Yokogawa voltage source
        print("\nConnecting to Yokogawa...")
        Info.yokogawa = Info.rm.open_resource(YOKOGAWA_ADDR)
        Info.yokogawa.timeout = 5000  # 5 second timeout (pyvisa uses milliseconds)
        Info.yokogawa.write_termination = '\n'
        Info.yokogawa.read_termination = '\n'

        # Initialize Yokogawa: Function, Range, Output On
        try:
            Info.yokogawa.write("F1R5O1E")
            time.sleep(0.5)
            print("Yokogawa initialized successfully")
        except Exception as e:
            print(f"Warning during Yokogawa initialization: {e}")
            print("Trying alternative initialization...")
            try:
                Info.yokogawa.write("F1E")  # Just function
                time.sleep(0.5)
                print("Yokogawa initialized with basic command")
            except Exception:
                print("Warning: Could not initialize Yokogawa, proceeding anyway...")

        # Open connection to the DMM
        print("\nConnecting to DMM...")
        Info.multimeter = Info.rm.open_resource(DMM_ADDR)
        Info.multimeter.timeout = 5000  # 5 second timeout
        Info.multimeter.write_termination = '\n'
        Info.multimeter.read_termination = '\n'

        # Identify the DMM if it answers *IDN? (not all instruments do)
        try:
            idn_dmm = Info.multimeter.query('*IDN?')
            print(f"Connected to DMM: {idn_dmm}")
        except Exception:
            print("DMM doesn't respond to *IDN?, but connection established")

        # Reset DMM
        try:
            Info.multimeter.write('*RST')
            time.sleep(1)
            print("DMM reset successfully")
        except Exception:
            print("Warning: Could not reset DMM")

        # --- Run iterations ---
        print(f"\nStarting {num_iterations} measurement iterations...")

        # Per-iteration surrogate grid, mean and std (filled in by driver1)
        all_data = {
            'voltages': None,
            'currents': [],
            'stds': []
        }

        # Create live plotting figure
        live_fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
        live_axes = (ax1, ax2, ax3)
        live_fig.suptitle('Live Data Accumulation', fontsize=16)

        for jter in range(num_iterations):
            print(f"\n{'='*50}")
            print(f"Starting iteration {jter + 1} of {num_iterations}")
            print(f"{'='*50}")

            # Set unique filename for each iteration
            Info.fname_accum = f'iv_data_1D_iter_{jter}.txt'

            # Run the measurement with live plotting
            driver1(jter=jter, all_data=all_data, live_fig=live_fig, live_axes=live_axes)

            if jter < num_iterations - 1:  # Don't delay after last iteration
                print(f"\nIteration {jter + 1} complete. Waiting 2 seconds before next iteration...")
                time.sleep(2)

        print(f"\nAll {num_iterations} iterations completed successfully!")

        # Measurements are finished: do not leave the source biased while the
        # (blocking) plot windows below stay open.
        _source_off()

        currents = np.array(all_data['currents'])
        voltages = all_data['voltages']
        n_done = len(currents)

        # Save final accumulated data
        print("\nSaving accumulated data...")
        np.savez('accumulated_data.npz',
                 voltages=voltages,
                 currents=currents,
                 stds=np.array(all_data['stds']))

        # Statistics across iterations. The standard error uses the sample std
        # (ddof=1) and the number of runs actually completed; one run has no spread.
        mean_current = currents.mean(axis=0)
        std_current = currents.std(axis=0)
        if n_done > 1:
            sem_current = currents.std(axis=0, ddof=1) / np.sqrt(n_done)
        else:
            sem_current = np.zeros_like(mean_current)

        # Create final summary plots
        print("Creating final summary plots...")
        plt.figure(figsize=(15, 5))

        # Plot 1: All iterations with mean
        plt.subplot(1, 3, 1)
        for current in currents:
            plt.plot(voltages, current, alpha=0.3, color='gray')
        plt.plot(voltages, mean_current, 'b-', linewidth=3, label='Mean')
        plt.xlabel('Voltage (V)')
        plt.ylabel('Current (A)')
        plt.title('All Iterations with Mean')
        plt.grid(True)
        plt.legend()

        # Plot 2: Mean with ~95% interval (±2 standard errors)
        plt.subplot(1, 3, 2)
        plt.plot(voltages, mean_current, 'b-', linewidth=2)
        plt.fill_between(voltages,
                         mean_current - 2*sem_current,
                         mean_current + 2*sem_current,
                         alpha=0.3, label='95% CI')
        plt.xlabel('Voltage (V)')
        plt.ylabel('Current (A)')
        plt.title(f'Mean Current with 95% CI (n={n_done})')
        plt.grid(True)
        plt.legend()

        # Plot 3: Iteration-to-iteration variability
        plt.subplot(1, 3, 3)
        plt.plot(voltages, std_current, 'r-', linewidth=2)
        plt.xlabel('Voltage (V)')
        plt.ylabel('Std. Dev. (A)')
        plt.title('Iteration-to-Iteration Variability')
        plt.grid(True)

        plt.tight_layout()
        plt.savefig('final_summary.pdf')
        plt.show()

        print("\nFinal plots saved as 'final_summary.pdf'")
        print("Accumulated data saved as 'accumulated_data.npz'")

        # Keep the live plot window open
        plt.show(block=True)

    except pyvisa.errors.VisaIOError as e:
        print(f"\nVISA Error: {e}")
        print("\nTroubleshooting steps:")
        print("1. Check that instruments are powered on")
        print("2. Verify GPIB cable connections")
        print("3. Confirm GPIB addresses (use NI MAX or similar tool)")
        print("4. Try different address formats:")
        print("   - GPIB0::1::INSTR or GPIB::1::INSTR")
        print("   - For USB-GPIB: USB0::0x####::0x####::SERIAL::INSTR")
        print("5. Ensure VISA drivers are installed (NI-VISA or similar)")
        print("6. Check if another program is using the instruments")

    except Exception as e:
        print(f"\nUnexpected error: {e}")
        import traceback
        traceback.print_exc()

    finally:
        # --- Close Connections (runs on success, error, and Ctrl+C) ---
        print("\nClosing connections...")
        _source_off()
        if Info.yokogawa is not None:
            try:
                Info.yokogawa.close()
                print("Yokogawa connection closed")
            except Exception as e:
                print(f"Warning: could not close Yokogawa connection: {e}")
        if Info.multimeter is not None:
            try:
                Info.multimeter.close()
                print("DMM connection closed")
            except Exception as e:
                print(f"Warning: could not close DMM connection: {e}")
        if Info.rm is not None:
            try:
                Info.rm.close()
                print("Resource manager closed")
            except Exception as e:
                print(f"Warning: could not close resource manager: {e}")