# **DATA GENERATION**

In [12]:
import numpy as np
import pandas as pd

def generate_fruit_dataset(
    n_orange=70,
    n_lemon=70,
    n_apple=70,
    random_seed=42,
    output_file="fruit_dataset_70.csv"
):
    np.random.seed(random_seed)

    # --- Define feature distributions ---
    # Oranges: heavier, larger, higher pH
    orange_weight   = np.random.normal(loc=150, scale=10,  size=n_orange)  # grams
    orange_diameter = np.random.normal(loc=8.0, scale=0.5, size=n_orange)  # cm
    orange_ph       = np.random.normal(loc=3.8, scale=0.2, size=n_orange)  # pH

    # Lemons: lighter, smaller, lower pH
    lemon_weight   = np.random.normal(loc=55,  scale=5,   size=n_lemon)      # grams
    lemon_diameter = np.random.normal(loc=4.5, scale=0.3, size=n_lemon)      # cm
    lemon_ph       = np.random.normal(loc=2.3, scale=0.2, size=n_lemon)      # pH

    # Apples: medium weight, medium diameter, medium pH
    apple_weight   = np.random.normal(loc=120, scale=8,   size=n_apple)    # grams
    apple_diameter = np.random.normal(loc=7.0, scale=0.4, size=n_apple)    # cm
    apple_ph       = np.random.normal(loc=3.0, scale=0.2, size=n_apple)    # pH

    # --- Create DataFrames for each class ---
    oranges_df = pd.DataFrame({
        "weight":   orange_weight,
        "diameter": orange_diameter,
        "ph_level": orange_ph,
        "label":    0  # Class 0 = Orange
    })

    lemons_df = pd.DataFrame({
        "weight":   lemon_weight,
        "diameter": lemon_diameter,
        "ph_level": lemon_ph,
        "label":    1  # Class 1 = Lemon
    })

    apples_df = pd.DataFrame({
        "weight":   apple_weight,
        "diameter": apple_diameter,
        "ph_level": apple_ph,
        "label":    2  # Class 2 = Apple
    })

    # --- Combine and shuffle dataset ---
    df = pd.concat([oranges_df, lemons_df, apples_df], ignore_index=True)

    # Shuffle rows
    df = df.sample(frac=1, random_state=random_seed).reset_index(drop=True)

    # --- Save to CSV ---
    df.to_csv(output_file, index=False)

    # --- Show first 5 rows ---
    print("First 5 rows of the dataset:")
    print(df.head())

    return df

if __name__ == "__main__":
    generate_fruit_dataset()

First 5 rows of the dataset:
       weight  diameter  ph_level  label
0  143.982934  7.292315  3.622097      0
1  112.114192  6.734950  3.095796      2
2   52.640341  4.607105  2.438029      1
3  128.956599  6.758406  3.044777      2
4  145.208258  7.224668  3.871557      0


In [13]:
df = generate_fruit_dataset()
print("DataFrame 'df' has been initialized with the 3-class dataset and is ready for use.")

First 5 rows of the dataset:
       weight  diameter  ph_level  label
0  143.982934  7.292315  3.622097      0
1  112.114192  6.734950  3.095796      2
2   52.640341  4.607105  2.438029      1
3  128.956599  6.758406  3.044777      2
4  145.208258  7.224668  3.871557      0
DataFrame 'df' has been initialized with the 3-class dataset and is ready for use.


In [14]:
numerical_cols = df.select_dtypes(include=np.number).columns.tolist()
if 'label' in numerical_cols:
    numerical_cols.remove('label')
print("Numerical columns identified for transformation:", numerical_cols)

Numerical columns identified for transformation: ['weight', 'diameter', 'ph_level']


In [15]:
import numpy as np

for col in numerical_cols:
    min_val = df[col].min()
    max_val = df[col].max()

    # Handle case where min_val equals max_val to avoid division by zero
    if (max_val - min_val) == 0:
        df[col] = 0 # Or any other appropriate value, e.g., np.pi / 2
    else:
        df[col] = (df[col] - min_val) * np.pi / (max_val - min_val)

print("DataFrame after min-max scaling to [0, pi] for numerical columns:")
print(df.head())

DataFrame after min-max scaling to [0, pi] for numerical columns:
     weight  diameter  ph_level  label
0  2.547322  1.995791  2.036108      0
1  1.775571  1.666483  1.422667      2
2  0.335319  0.409287  0.655993      1
3  2.183436  1.680341  1.363200      2
4  2.576995  1.955823  2.326872      0


In [16]:
df

Unnamed: 0,weight,diameter,ph_level,label
0,2.547322,1.995791,2.036108,0
1,1.775571,1.666483,1.422667,2
2,0.335319,0.409287,0.655993,1
3,2.183436,1.680341,1.363200,2
4,2.576995,1.955823,2.326872,0
...,...,...,...,...
205,0.334954,0.466882,0.574778,1
206,2.275319,2.175072,2.311786,0
207,0.397980,0.478463,0.833392,1
208,1.734470,1.609363,1.398964,2


# **LIBRARY INSTALLATION AND IMPORTING** (QISKIT)s




In [17]:
!pip install qiskit
!pip install qiskit-aer
!pip install pylatexenc
!pip install ipywidgets



In [18]:
from qiskit.circuit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit_aer.primitives import Sampler
from qiskit.visualization import plot_histogram
from qiskit.visualization import plot_bloch_vector
from qiskit.quantum_info import Statevector
from qiskit.visualization import plot_bloch_multivector
from qiskit.visualization import circuit_drawer

from IPython.display import Image, display



# **OTHER LIBRARIES**

In [19]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

# colors eieiei
RED = '\033[91m'
GREEN = '\033[92m'
YELLOW = '\033[93m'
BLUE = '\033[94m'
END = '\033[0m'
import json

# **FITNESS CREATION**

In [20]:
def create_fitness(n_controls, n_targets, del_phi_set):
    """
    สร้าง Unitary Gate สำหรับ n_controls และ n_targets
    โดยใช้มุมจาก del_phi_set (ซึ่งเป็น List of lists)
    """

    if len(del_phi_set) != n_targets:
        raise ValueError(f"Error: del_phi_set มี {len(del_phi_set)} แถว แต่คุณกำหนด n_targets = {n_targets}")

    # 1. สร้างวงจรขนาดรวม (Controls + Targets)
    total_qubits = n_controls + n_targets
    qc_temp = QuantumCircuit(total_qubits)

    # 2. วนลูป Loop นอก: จัดการ Target ทีละตัว
    for t in range(n_targets):

        # ระบุตำแหน่งของ Target Qubit ในวงจร (ต่อท้าย Control ตัวสุดท้าย)
        # เช่น ถ้ามี 3 Controls -> Target แรกคือ index 3, ถัดไปคือ 4...
        target_qubit_index = n_controls + t

        # ดึง del_phi_list เฉพาะของ Target ตัวนี้ออกมาจาก Set
        # del_phi_set[0] สำหรับ target แรก, del_phi_set[1] สำหรับ target สอง...
        current_phi_list = del_phi_set[t]

        # ตรวจสอบว่าจำนวนมุมพอดีกับ control ไหม
        if len(current_phi_list) != n_controls:
             raise ValueError(f"Error: ที่ Target {t} มีมุม {len(current_phi_list)} ตัว แต่มี Control {n_controls} ตัว")

        # 3. วนลูป Loop ใน: ผูก Control แต่ละตัว เข้ากับ Target ตัวปัจจุบัน
        for c in range(n_controls):
            angle = current_phi_list[c]

            # ใช้ Logic เดิมของคุณ:
            # ถ้า Control เป็น 1 -> หมุน +phi
            qc_temp.crz(angle, c, target_qubit_index, ctrl_state=1)

            # ถ้า Control เป็น 0 -> หมุน -phi
            qc_temp.crz(-angle, c, target_qubit_index, ctrl_state=0)

    # 4. แปลงวงจรทั้งหมดเป็น Gate ก้อนเดียว
    custom_gate = qc_temp.to_gate(label=f"Entangle_{n_targets}Targets")

    return custom_gate

def create_delphi_list(m_input,w_list):
  del_phi_list = []
  if m_input != len(w_list):
    raise AssertionError("m_input must be equal to the length of w_list")

  else:
    for i in range(m_input):
      del_phi_list.append(w_list[i]*np.pi/2)
  return del_phi_list

def create_delphi_set(m_input, w_set):

    del_phi_set = []

    for w_list in w_set:
        phi_list = create_delphi_list(m_input, w_list)
        del_phi_set.append(phi_list)

    return del_phi_set


In [None]:
import numpy as np
import json
import os

def save_delphi_set(del_phi_set, base_filename="del_phi_set"):
    """
    Save del_phi_set to two files:
    1. A readable .txt file   (JSON format)
    2. A binary .npy file     (NumPy format)

    del_phi_set must be a list of lists of floats.
    """

    # 1. Save as JSON-readable TXT
    txt_filename = f"{base_filename}.txt"
    with open(txt_filename, "w") as f:
        f.write(json.dumps(del_phi_set, indent=4))
    print(f"[✔] Saved del_phi_set to {txt_filename}")

    # 2. Save as NumPy binary
    npy_filename = f"{base_filename}.npy"
    np.save(npy_filename, np.array(del_phi_set, dtype=float))
    print(f"[✔] Saved del_phi_set to {npy_filename}")

    return txt_filename, npy_filename

# Example usage:
# del_phi_set = create_delphi_set(m_input, w_set)
# save_delphi_set(del_phi_set, "my_saved_delphi_set")

In [21]:
# --- 1. เตรียม Gate และ Parameters (ทำนอก Loop เพื่อความเร็ว) ---

n_controls = 3
n_targets = 2
w_set = [[0.3,0.35,0.35],[0.6,0.15,0.25]]

sampler = Sampler()

# สร้าง del_phi_set จาก w_set ที่คุณมีอยู่แล้ว
# (สมมติว่า w_set คือ List of Lists ขนาด 2x3 สำหรับ 2 targets, 3 features)
del_phi_set = create_delphi_set(n_controls, w_set)

# สร้าง Gate ก้อนใหญ่เตรียมไว้
my_gate = create_fitness(n_controls, n_targets, del_phi_set)

# เตรียมตัวแปรเก็บผลลัพธ์ (Confusion Matrix แบบง่าย)
# รูปแบบ: { Class_ของจริง : { ทำนายได้0: จำนวน, ทำนายได้1: จำนวน, ... } }
results_log = {
    0: {0: 0, 1: 0, 2: 0, 3: 0},
    1: {0: 0, 1: 0, 2: 0, 3: 0},
    2: {0: 0, 1: 0, 2: 0, 3: 0}
}

# --- 2. เริ่มวนลูปทำนายผล ---

# ใช้ len(df) หรือจำนวนที่คุณต้องการทดสอบ
for i in tqdm(range(len(df)), desc="Processing 3 Classes"):

    # --- สร้างวงจร (Circuit Construction) ---
    # ใช้ 5 Qubits (3 Controls + 2 Targets) และ 2 Classical Bits
    qc = QuantumCircuit(n_controls + n_targets, n_targets)

    # ใส่ข้อมูล (Encoding)
    theta = [df.weight[i], df.diameter[i], df.ph_level[i]]

    # หมุน Controls
    qc.ry(theta[0], 0)
    qc.ry(theta[1], 1)
    qc.ry(theta[2], 2)

    # เตรียม Targets (ตาม Code ใหม่ของคุณใช้ rx -pi/2)
    qc.rx(-np.pi/2, 3)
    qc.rx(-np.pi/2, 4)

    # ใส่ Gate ที่เราสร้างไว้
    qc.append(my_gate, range(n_controls + n_targets))

    # ปรับ Basis ก่อนวัด
    qc.h(3)
    qc.h(4)

    # วัดผล (Target 1 -> bit 0, Target 2 -> bit 1)
    qc.measure(3, 0)
    qc.measure(4, 1)

    # --- รันและเก็บผล (Execution & Logging) ---
    job = sampler.run(qc, shots=1)
    result = job.result()
    dist = result.quasi_dists[0]

    # ดึงค่าที่วัดได้ (เนื่องจาก shots=1 จะมี key เดียว)
    # ผลลัพธ์จะเป็น integer: 0('00'), 1('01'), 2('10'), 3('11')
    prediction = list(dist.keys())[0]

    # ดึงเฉลยจริงจาก DataFrame
    actual_label = df.label[i]

    # บันทึกผลลง Log
    # ตรวจสอบก่อนว่า label ใน df เป็น 0,1,2 หรือไม่ เพื่อป้องกัน error
    if actual_label in results_log:
        results_log[actual_label][prediction] += 1

# --- 3. แสดงผลลัพธ์ ---

print("\n--- Evaluation Results ---")
# Mapping ปกติ: 0->Class0, 1->Class1, 2->Class2 (3->Unused/Error)
for actual_class in results_log:
    print(f"Actual Class {actual_class}:")
    print(f"  Predicted '00' (0): {results_log[actual_class][0]}")
    print(f"  Predicted '01' (1): {results_log[actual_class][1]}")
    print(f"  Predicted '10' (2): {results_log[actual_class][2]}")
    print(f"  Predicted '11' (3): {results_log[actual_class][3]}") # State นี้มักจะไม่ค่อยมีข้อมูลถ้าเทรนมาดี
    print("-" * 20)

Processing 3 Classes:   0%|          | 0/210 [00:00<?, ?it/s]


--- Evaluation Results ---
Actual Class 0:
  Predicted '00' (0): 2
  Predicted '01' (1): 5
  Predicted '10' (2): 8
  Predicted '11' (3): 55
--------------------
Actual Class 1:
  Predicted '00' (0): 61
  Predicted '01' (1): 3
  Predicted '10' (2): 4
  Predicted '11' (3): 2
--------------------
Actual Class 2:
  Predicted '00' (0): 19
  Predicted '01' (1): 6
  Predicted '10' (2): 7
  Predicted '11' (3): 38
--------------------


# **CODENAME**

In [22]:
def generate_w_set_codename(w_set_list):
    """
    Generates a unique string 'code name' for a w_set list.
    """
    flat_list = []
    for sublist in w_set_list:
        flat_list.extend(sublist)
    # Convert numbers to strings and join them, replacing '.' with 'p' for file-name like string
    codename = "w_" + "_".join([str(x).replace('.', 'p') for x in flat_list])
    return codename

In [23]:
def get_w_from_codename(codename):

    data_string = codename[2:]
    string_values = data_string.split('_')
    w_values = []
    for s_val in string_values:
        try:
            w_values.append(float(s_val.replace('p', '.')))
        except ValueError:
            raise ValueError(f"Invalid codename: Could not convert '{s_val}' to a number.")

    if len(w_values) != 6:
        raise ValueError("Invalid codename: Expected 6 numerical values for w_set, but found "
                         f"{len(w_values)}. (Expected 2 sublists of 3 elements each)")

    w_set = [
        [w_values[0], w_values[1], w_values[2]],
        [w_values[3], w_values[4], w_values[5]]
    ]
    return w_set

# **DISPLAYING AND PLOTTING**

In [24]:
all_run_results = {}

In [25]:
def display_run_results(codename):
    if codename in all_run_results:
        results = all_run_results[codename]
        print(f"\n--- Evaluation Results for Run: {codename} ---")
        for actual_class in results:
            print(f"Actual Class {actual_class}:")
            print(f"  Predicted '00' (0): {results[actual_class][0]}")
            print(f"  Predicted '01' (1): {results[actual_class][1]}")
            print(f"  Predicted '10' (2): {results[actual_class][2]}")
            print(f"  Predicted '11' (3): {results[actual_class][3]}")
            print("-" * 20)
    else:
        print(f"No results found for codename: {codename}")

print("Function 'display_run_results' defined.")

Function 'display_run_results' defined.


In [26]:
def display_all_run_results():
  for codename, results in all_run_results.items():
      print(f"\n--- Evaluation Results for Run: {codename} ---")
      for actual_class in results:
          print(f"Actual Class {actual_class}:")
          print(f"  Predicted '00' (0): {results[actual_class][0]}")
          print(f"  Predicted '01' (1): {results[actual_class][1]}")
          print(f"  Predicted '10' (2): {results[actual_class][2]}")
          print(f"  Predicted '11' (3): {results[actual_class][3]}")
          print("-" * 20)

In [27]:
def plot_results_grouped_bar(results_log, codename):
    """
    Visualizes the classification results using a grouped bar chart.
    """
    states = ['00 (State 0)', '01 (State 1)', '10 (State 2)', '11 (State 3)']
    n_states = len(states)

    class_0_counts = [results_log[0][i] for i in range(4)]
    class_1_counts = [results_log[1][i] for i in range(4)]
    class_2_counts = [results_log[2][i] for i in range(4)]

    x = np.arange(n_states)
    width = 0.25

    fig, ax = plt.subplots(figsize=(12, 7))

    rects1 = ax.bar(x - width, class_0_counts, width, label='Actual Class 0', color='#1f77b4', alpha=0.8)
    rects2 = ax.bar(x,         class_1_counts, width, label='Actual Class 1', color='#ff7f0e', alpha=0.8)
    rects3 = ax.bar(x + width, class_2_counts, width, label='Actual Class 2', color='#2ca02c', alpha=0.8)

    ax.set_ylabel('Frequency (Count)')
    ax.set_title(f'Quantum Classification Results for: {codename}')
    ax.set_xticks(x)
    ax.set_xticklabels(states)
    ax.legend()
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)

    def autolabel(rects):
        for rect in rects:
            height = rect.get_height()
            ax.annotate('{}'.format(height),
                        xy=(rect.get_x() + rect.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=9)

    autolabel(rects1)
    autolabel(rects2)
    autolabel(rects3)

    plt.tight_layout()
    plt.show()

# The loop to plot all results has been removed from here to avoid automatic plotting.
# Use plot_specific_run_results(codename) to plot individual results or
# implement a separate loop if you wish to plot all at once.


In [28]:
def plot_specific_run_results(codename):
    if codename in all_run_results:
        results = all_run_results[codename]
        plot_results_grouped_bar(results, codename)
    else:
        print(f"No results found for codename: {codename}")

print("Function 'plot_specific_run_results' defined.")

Function 'plot_specific_run_results' defined.


# **run_weighed_sampler DEFINITION**

In [29]:
def run_weighed_sampler(w_set_input, data_df):
    global n_controls, n_targets, sampler, create_delphi_set, create_fitness

    # ... [Previous setup code remains exactly the same] ...

    del_phi_set = create_delphi_set(n_controls, w_set_input)
    my_gate = create_fitness(n_controls, n_targets, del_phi_set)

    results_log_func = {
        0: {0: 0, 1: 0, 2: 0, 3: 0},
        1: {0: 0, 1: 0, 2: 0, 3: 0},
        2: {0: 0, 1: 0, 2: 0, 3: 0}
    }

    # --- TQDM UPDATE HERE ---
    # leave=False: Removes the bar when done, so the next J_plus/J_minus bar can take its spot.
    for i in tqdm(range(len(data_df)), desc="Processing 3 Classes"):

        # ... [The rest of your quantum circuit code remains exactly the same] ...

        qc = QuantumCircuit(n_controls + n_targets, n_targets)
        theta = [data_df.weight[i], data_df.diameter[i], data_df.ph_level[i]]

        qc.ry(theta[0], 0)
        qc.ry(theta[1], 1)
        qc.ry(theta[2], 2)
        qc.rx(-np.pi/2, 3)
        qc.rx(-np.pi/2, 4)
        qc.append(my_gate, range(n_controls + n_targets))
        qc.h(3)
        qc.h(4)
        qc.measure(3, 0)
        qc.measure(4, 1)

        job = sampler.run(qc, shots=1)
        result = job.result()
        dist = result.quasi_dists[0]
        prediction = list(dist.keys())[0]
        actual_label = data_df.label[i]

        if actual_label in results_log_func:
            results_log_func[actual_label][prediction] += 1

    return results_log_func

# codename testing;

In [30]:
my_w_set = [[0.3, 0.35, 0.35], [0.6, 0.15, 0.25]]
test_codename = generate_w_set_codename(my_w_set )
print(f"Sample w_set: {my_w_set}")
print(f"Generated Code Name: {test_codename}")

Sample w_set: [[0.3, 0.35, 0.35], [0.6, 0.15, 0.25]]
Generated Code Name: w_0p3_0p35_0p35_0p6_0p15_0p25


In [31]:
gunw = [[0.3, 0.35, 0.35], [0.6, 0.15, 0.25]]
gunw_codename = generate_w_set_codename(gunw)
print(gunw_codename)
print(get_w_from_codename(gunw_codename))

w_0p3_0p35_0p35_0p6_0p15_0p25
[[0.3, 0.35, 0.35], [0.6, 0.15, 0.25]]


# **SPSA PREF**

In [32]:
import numpy as np

def flatten_w_set(w_set):
    w_np = np.array(w_set)
    w_shape = w_np.shape
    w_vector = w_np.flatten()
    return w_vector, w_shape

def unflatten_w_set(w_vector, w_shape):
    w_np = w_vector.reshape(w_shape)
    w_set = w_np.tolist()
    return w_set

'''def calculate_cost(w_set_input, data_df):
    results_log = run_weighed_sampler(w_set_input, data_df)

    total_samples = len(data_df)
    correct_predictions = 0

    correct_predictions += results_log[0][0]
    correct_predictions += results_log[1][1]
    correct_predictions += results_log[2][2]

    if total_samples > 0:
        accuracy = correct_predictions / total_samples
        cost = 1.0 - accuracy
    else:
        cost = 1.0

    return cost'''

'def calculate_cost(w_set_input, data_df):\n    results_log = run_weighed_sampler(w_set_input, data_df)\n\n    total_samples = len(data_df)\n    correct_predictions = 0\n\n    correct_predictions += results_log[0][0]\n    correct_predictions += results_log[1][1]\n    correct_predictions += results_log[2][2]\n\n    if total_samples > 0:\n        accuracy = correct_predictions / total_samples\n        cost = 1.0 - accuracy\n    else:\n        cost = 1.0\n\n    return cost'

In [33]:
def run_spsa_optimization(initial_w_set, K_ITERATIONS, a = 0, c = 0):


    best_w_set, final_min_cost, cost_history = spsa_optimization_revised(
          initial_w_set=initial_w_set,
          data_df=df,
          K=K_ITERATIONS
      )

    print("\n--- OPTIMIZATION RESULTS ---")
    print("Best W_SET found:")
    print(best_w_set)
    print("; codename:")
    print(generate_w_set_codename(best_w_set))
    print(f"Lowest Misclassification Cost achieved: {final_min_cost}")

    import json
    data_to_dump = {
      'final_cost': final_min_cost,
      'weight': best_w_set,
      'w codename': generate_w_set_codename(best_w_set)
    }

    formatted_json_str = (
      json.dumps(data_to_dump, separators=(',', ':'))
      .replace(',"w codename"', ',\n  "w codename"')
      .replace(',"weight"', ',\n  "weight"')
      .replace('{', '{\n  ').replace('}', '\n}')
      .replace('{', '{\n  ').replace('}', '\n}')
    )


    with open('best_weights.json', 'a') as f:
      f.write(formatted_json_str)
      f.write('\n\n')

    print("results saved to file")



# **COST CALCULATION**

In [34]:
def calculate_cost(w_set_input, data_df):
    # Step 1: run sampler
    results_log = run_weighed_sampler(w_set_input, data_df)

    # Step 2: build class vectors (counts across states)
    class_keys = sorted(results_log.keys())
    vectors = []

    for cls in class_keys:
        v = np.array([results_log[cls].get(s, 0) for s in range(4)], dtype=float)
        total = v.sum()
        if total > 0:
            v = v / total  # normalize to probability vector
        vectors.append(v)

    # Step 3: build matrix P (rows = class vectors)
    P = np.vstack(vectors)   # shape (num_classes, 4)

    # Step 4: Gram matrix G = P * P^T
    G = P @ P.T

    # Step 5: determinant of Gram matrix
    detG = np.linalg.det(G)

    # Step 6: convert determinant to cost
    # Larger det → more orthogonal → BETTER → lower cost
    cost = 1.0 - detG  # stays bounded [0,1] under normalization

    return float(cost)

# **SPSA MAIN**

In [35]:
def spsa_optimization_revised(initial_w_set, data_df, K=100, a=0.43, c=0.04, A=10, alpha=0.602, gamma=0.101):

    print(f"Starting SPSA optimization for {K} iterations...")

    w_vector, w_shape = flatten_w_set(initial_w_set)
    p = len(w_vector)

    best_w_vector = w_vector.copy()
    min_cost = float('inf')
    cost_history = []

    # Initialize previous weight vector to the start, or to the best_w_vector
    # This will hold the vector *before* the update that led to a bad cost.
    w_vector_before_update = w_vector.copy()

    # --- TQDM UPDATE HERE ---
    # position=0: Keeps this bar at the top.
    pbar = tqdm(range(K), desc="SPSA Iterations", position=0)
    for k in pbar:

        # --- Store the current w_vector before calculation of Delta_k and update ---
        w_vector_before_update = w_vector.copy()

        # --- SPSA Math remains exactly the same ---
        a_k = a / (k + 1 + A)**alpha
        c_k = c / (k + 1)**gamma

        # --- Roll-back Logic Loop ---
        # We will loop1 until the cost improves or we reach a maximum number of re-tries
        max_re_tries = 3 # Set a sensible limit to prevent infinite loops
        for re_try in range(max_re_tries):

            # 1. Perturbation
            Delta_k = np.random.choice([-1., 1.], size=p)

            # 2. Perturbed Weights
            w_plus  = w_vector_before_update + c_k * Delta_k # Use the stored vector
            w_minus = w_vector_before_update - c_k * Delta_k # Use the stored vector

            # 3. Calculate Cost
            J_plus = calculate_cost(unflatten_w_set(w_plus, w_shape), data_df)
            print(f'k{k}/r{re_try}[+] : {J_plus}, \n weight: {unflatten_w_set(w_plus, w_shape)}')
            J_minus = calculate_cost(unflatten_w_set(w_minus, w_shape), data_df)
            print(f'k{k}/r{re_try}[-] : {J_minus}, \n weight: {unflatten_w_set(w_minus, w_shape)}')

            # 4. Gradient Approximation
            g_hat = (J_plus - J_minus) / (2 * c_k * Delta_k + 1e-10)

            # 5. Update (Candidate w_vector)
            w_candidate = w_vector_before_update - a_k * g_hat
            w_candidate = np.clip(w_candidate, 0, 1)

            # 6. Evaluate Candidate Cost
            current_cost = calculate_cost(unflatten_w_set(w_candidate, w_shape), data_df)
            cost_history.append(current_cost)

            # 7. Roll-back Decision
            if current_cost < min_cost:
                # Success! Accept the update.
                w_vector = w_candidate.copy() # Update the main vector

                # --- Update Best Solution and Save ---
                min_cost = current_cost
                best_w_vector = w_vector.copy()
                print(f'{GREEN}k{k}[current] : {current_cost} (Accepted){END}, w')

                # Save results to file (using the same logic from original code)
                import json
                data_to_dump = {
                  'better cost': min_cost,
                  'weight': unflatten_w_set(best_w_vector, w_shape),
                  'w codename': generate_w_set_codename(unflatten_w_set(best_w_vector, w_shape))
                }

                formatted_json_str = (
                  json.dumps(data_to_dump, separators=(',', ':'))
                  .replace(',"w codename"', ',\n  "w codename"')
                  .replace(',"weight"', ',\n  "weight"')
                  .replace('{', '{\n  ').replace('}', '\n}')
                  .replace('{', '{\n  ').replace('}', '\n}')
                )
                with open('best_weights.json', 'a') as f:
                  f.write(formatted_json_str)
                  f.write('\n\n')
                print("results saved to file")

                break # Exit the re-try loop and move to the next main iteration (k+1)
            else:
                # Failure! Current cost is worse.
                # w_vector remains w_vector_before_update for the next re-try
                print(f'{RED}k{k}/r{re_try}[current] : {current_cost} (Rejected, Re-trying...){END}')

                if re_try == max_re_tries - 1:
                    # If all re-tries fail, still use the best known vector for the next main iteration
                    w_vector = w_vector_before_update.copy()
                    print(f'Max re-tries ({max_re_tries}) reached. Using w_vector from k={k-1}.')


    best_w_set = unflatten_w_set(best_w_vector, w_shape)
    final_min_cost = min_cost
    return best_w_set, min_cost, cost_history

# **OPTIMIZATION**

In [36]:
initial_w_set = [[0.40152766722492955,0.4440525783787271,0.4440525783787271],[0.3559474217352591,0.40152766722492955,0.3559474217352591]]
run_spsa_optimization(initial_w_set, 100, a = 0.5, c = 0.05)


Starting SPSA optimization for 100 iterations...


SPSA Iterations:   0%|          | 0/100 [00:00<?, ?it/s]

Processing 3 Classes:   0%|          | 0/210 [00:00<?, ?it/s]

k0/r0[+] : 0.9956680368213925, 
 weight: [[0.41499673835294437, 0.3638265227353272, 0.426919926432194], [0.3424783505765991, 0.32130161158152964, 0.35617347734801386]]


Processing 3 Classes:   0%|          | 0/210 [00:00<?, ?it/s]

k0/r0[-] : 0.9951002198403726, 
 weight: [[0.3349967383529444, 0.44382652273532713, 0.506919926432194], [0.42247835057659905, 0.4013016115815296, 0.4361734773480138]]


Processing 3 Classes:   0%|          | 0/210 [00:00<?, ?it/s]

[92mk0[current] : 0.9944548035597413 (Accepted)[0m, w
results saved to file


Processing 3 Classes:   0%|          | 0/210 [00:00<?, ?it/s]

k1/r0[+] : 0.9965420248025908, 
 weight: [[0.33698072171908966, 0.4418425393709833, 0.5049359430678501], [0.34590344829730857, 0.3247267093022391, 0.43418949398366996]]


Processing 3 Classes:   0%|          | 0/210 [00:00<?, ?it/s]

k1/r0[-] : 0.9855443656639665, 
 weight: [[0.4115716406340363, 0.36725162045603665, 0.4303450241529035], [0.4204943672122552, 0.39931762821718575, 0.35959857506872334]]


Processing 3 Classes:   0%|          | 0/210 [00:00<?, ?it/s]

[91mk1/r0[current] : 0.9954968286683269 (Rejected, Re-trying...)[0m


Processing 3 Classes:   0%|          | 0/210 [00:00<?, ?it/s]

KeyboardInterrupt: 