# Partial Optimization of Control Parameters

In this notebook, we demonstrate how to **optimize only a subset** of the control parameters (for example, just the Rabi amplitudes for \(\Omega_i\) and \(\Omega_r\)), while keeping other parameters (detuning \(\Delta_i(t)\) or phases) fixed. This can be useful if we want to test the effect of only adjusting certain parts of the pulse, or if we already know certain parameters should not change.


## Motivation

Each of `ctrl_a`, `ctrl_b`, and `ctrl_c` has 5 columns (one column per control channel in the real code). The code below refers to them by zero-based column index:
1. \(\Delta_i(t)\) detuning (column 0)
2. Rabi \(\Omega_i\) amplitude (column 1)
3. Rabi \(\Omega_i\) phase (column 2)
4. Rabi \(\Omega_r\) amplitude (column 3)
5. Rabi \(\Omega_r\) phase (column 4)

Sometimes we want to freeze (not optimize) certain parameters, for example keep detuning and phases the same, while only optimizing amplitude. We'll show how to do that by applying a mask to the gradients so that those columns remain constant.

## Imports and Setup

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from functools import partial

from quantum_optimal_control.two_qubit.propagator_vl_jax import (
    PropagatorVLJAX,
    single_optimization_step
)
from quantum_optimal_control.toolkits.plotting_helper import PlotPlotter, getStylishFigureAxes

key = jax.random.PRNGKey(123)


### Create a Custom Single Step Function with Mask

We can partially optimize by multiplying the gradient of the columns we do **not** want to change by 0. This effectively 'freezes' them at their initial values.

For instance, if we only want to optimize the columns for Rabi_i amplitude (column 1) and Rabi_r amplitude (column 3) among the 5 columns, we'd define a mask:
```
mask = [
   0.0,   # column 0 (Delta)
   1.0,   # column 1 (Rabi_i amplitude)
   0.0,   # column 2 (Rabi_i phase)
   1.0,   # column 3 (Rabi_r amplitude)
   0.0    # column 4 (Rabi_r phase)
]
```
Then when we do the gradient update, we do `gradA = gradA * mask` for each row, and similarly for `gradB` and `gradC`.
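
To see how one per-column mask acts on every row, here is a minimal sketch on a toy array. The array `grad_a` is a made-up stand-in for a real gradient, not output from the propagator.

```python
import jax.numpy as jnp

# Hypothetical stand-in for a gradient of shape [input_dim, 5]; here input_dim = 3.
grad_a = jnp.arange(15, dtype=jnp.float32).reshape(3, 5) + 1.0

# One entry per column: 1.0 = optimize, 0.0 = freeze.
mask = jnp.array([0.0, 1.0, 0.0, 1.0, 0.0])

# Broadcasting applies the same column mask to every row.
grad_a_masked = grad_a * mask

print(grad_a_masked)
```

Because the frozen columns of the masked gradient are exactly zero, a plain gradient-descent update leaves those entries of the control array untouched.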


In [None]:
# We'll define a new single-step function that uses a gradient mask.

def single_optimization_step_masked(
    propagator: PropagatorVLJAX,
    ctrl_a: jnp.ndarray,
    ctrl_b: jnp.ndarray,
    ctrl_c: jnp.ndarray,
    mask_a: jnp.ndarray,
    mask_b: jnp.ndarray,
    mask_c: jnp.ndarray,
    lr: float = 0.02
):
    """
    Similar to single_optimization_step, but we apply a mask to the gradients.
    mask_* should have the same shape as ctrl_*.
    """
    def loss_fn(a, b, c):
        return propagator.target(a, b, c)

    cost, grads = jax.value_and_grad(loss_fn, argnums=(0,1,2))(ctrl_a, ctrl_b, ctrl_c)
    gradA, gradB, gradC = grads

    # Apply mask
    gradA_masked = gradA * mask_a
    gradB_masked = gradB * mask_b
    gradC_masked = gradC * mask_c

    # Update only the allowed columns
    new_ctrl_a = ctrl_a - lr * gradA_masked
    new_ctrl_b = ctrl_b - lr * gradB_masked
    new_ctrl_c = ctrl_c - lr * gradC_masked

    return cost, new_ctrl_a, new_ctrl_b, new_ctrl_c
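
Before running this on the real propagator, we can sanity-check the masked update on a toy problem. The sketch below is only an illustration: `toy_loss` is a hypothetical quadratic that stands in for `propagator.target`, and `masked_step` is a single-array version of the update rule in `single_optimization_step_masked`.

```python
import jax
import jax.numpy as jnp

def toy_loss(ctrl):
    # Hypothetical stand-in for propagator.target: a quadratic with its minimum at 1.
    return jnp.sum((ctrl - 1.0) ** 2)

def masked_step(ctrl, mask, lr=0.02):
    # Same masked gradient-descent rule, written for a single control array.
    cost, grad = jax.value_and_grad(toy_loss)(ctrl)
    return cost, ctrl - lr * grad * mask

ctrl_init = jnp.zeros((4, 5))
mask = jnp.tile(jnp.array([0.0, 1.0, 0.0, 1.0, 0.0]), (4, 1))

ctrl = ctrl_init
for _ in range(300):
    cost, ctrl = masked_step(ctrl, mask)

# Columns 0, 2, 4 are frozen; columns 1, 3 are free to move toward the minimum.
print("frozen columns unchanged:", bool(jnp.all(ctrl[:, ::2] == ctrl_init[:, ::2])))
print("free columns near 1:", bool(jnp.allclose(ctrl[:, 1::2], 1.0, atol=1e-3)))
```

The frozen columns keep a nonzero raw gradient throughout, so the cost cannot reach zero here. The mask only stops those entries from being updated.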


## Set Up a Simple Example

We can do a short optimization on the usual 2-qubit gate cost function, but only updating, for instance, the amplitude columns (1 and 3). We'll freeze the detuning (column 0) and the phases (columns 2 and 4).

In [None]:
# Define system parameters quickly
V_int = 2 * np.pi * 10e6
tau = 324e-9
Delta_i = 2 * np.pi * -35.7e6
Rabi_i_val = 2 * np.pi * 100e6
Rabi_r_val = 2 * np.pi * 100e6
del_total = 0.0
Gammas = [0.1, 1.0, 0.05, 0.05]

# Time grid
nt = 200
pad = 20
delta_t = 3e-9

f_std = 50e6
input_dim = 10
numb_ctrl_amps = 5

# Create the JAX-based propagator
prop_jax = PropagatorVLJAX(
    input_dim=input_dim,
    no_of_steps=nt,
    pad=pad,
    f_std=f_std,
    delta_t=delta_t,
    del_total=del_total,
    V_int=V_int,
    Delta_i=Delta_i,
    Rabi_i=Rabi_i_val,
    Rabi_r=Rabi_r_val,
    Gammas_all=Gammas
)


### Initialize Controls
We set the detuning (column 0) to a fixed value and the phases (columns 2 and 4) to small random values; neither is updated during the optimization. The amplitude columns (1 and 3) start from random values and are the only ones allowed to change.

In [None]:
# We'll define ctrl_a, ctrl_b, ctrl_c, each shape [input_dim, 5]

key, sub1, sub2, sub3 = jax.random.split(key, 4)

ctrl_a_init = jnp.zeros((input_dim, numb_ctrl_amps))
ctrl_b_init = jnp.zeros((input_dim, numb_ctrl_amps))
ctrl_c_init = jnp.zeros((input_dim, numb_ctrl_amps))

# Let's put Delta_i(t) in column 0 => we want to freeze this.
# We'll fix it as e.g. -0.8 amplitude
ctrl_a_init = ctrl_a_init.at[:,0].set(-0.8)  

# We'll set amplitude columns 1,3 to random initial values in [-0.2, 0.2]
ctrl_a_init = ctrl_a_init.at[:,1].set(
    jax.random.uniform(sub1, shape=(input_dim,), minval=-0.2, maxval=0.2)
)
ctrl_a_init = ctrl_a_init.at[:,3].set(
    jax.random.uniform(sub2, shape=(input_dim,), minval=-0.2, maxval=0.2)
)

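# We'll set phases in columns 2,4 to random small values, but freeze them.
# Split sub3 so the two phase columns get independent draws; reusing the same
# key for both would give them identical values.
sub3a, sub3b = jax.random.split(sub3)
ctrl_a_init = ctrl_a_init.at[:,2].set(
    jax.random.uniform(sub3a, shape=(input_dim,), minval=-0.3, maxval=0.3)
)
ctrl_a_init = ctrl_a_init.at[:,4].set(
    jax.random.uniform(sub3b, shape=(input_dim,), minval=-0.3, maxval=0.3)
)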

# For simplicity, ctrl_b_init, ctrl_c_init remain zero:
# they shift the center and width of the Gaussian modes. We'll keep them zero.
print("ctrl_a_init:", ctrl_a_init.shape)
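
A short aside on the random draws: JAX random functions are deterministic in the key, so every column that should receive independent values needs its own subkey. The sketch below uses an arbitrary seed and toy shapes to show the difference.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key, sub_a, sub_b = jax.random.split(key, 3)

# Reusing one subkey reproduces exactly the same samples.
x1 = jax.random.uniform(sub_a, shape=(4,), minval=-0.3, maxval=0.3)
x2 = jax.random.uniform(sub_a, shape=(4,), minval=-0.3, maxval=0.3)

# A distinct subkey gives an independent draw.
x3 = jax.random.uniform(sub_b, shape=(4,), minval=-0.3, maxval=0.3)

print("same key, same values:", bool(jnp.array_equal(x1, x2)))
print("different key, same values:", bool(jnp.array_equal(x1, x3)))
```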


### Define the Gradient Mask

We want to freeze columns 0, 2, and 4 of `ctrl_a`, and we also freeze `ctrl_b` and `ctrl_c` entirely. The easiest approach is to set the mask to 0 wherever no update should happen and to 1 in the amplitude columns (1 and 3) of `ctrl_a`, so that only those entries are updated.

In [None]:
# Create a mask of shape [input_dim, 5]
mask_col = np.array([0,1,0,1,0], dtype=np.float32)
mask_a = jnp.tile(mask_col, (input_dim,1))

# To keep the center/width of the Gaussians from changing,
# we use all-zero masks for ctrl_b and ctrl_c.
mask_b = jnp.zeros_like(mask_a)
mask_c = jnp.zeros_like(mask_a)

print("mask_a shape =", mask_a.shape)
print(mask_a[0])
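
If the set of trainable channels changes often, it can be less error-prone to build the mask from channel names rather than hand-written 0/1 lists. The helper below is a hypothetical convenience, not part of `quantum_optimal_control`; the names in `COLUMN_NAMES` are labels we made up for the five columns listed in the Motivation section.

```python
import jax.numpy as jnp

# Column order follows the Motivation section; the names are hypothetical labels.
COLUMN_NAMES = ("delta", "rabi_i_amp", "rabi_i_phase", "rabi_r_amp", "rabi_r_phase")

def make_column_mask(trainable, input_dim):
    """Return a float mask of shape [input_dim, 5] with 1.0 in trainable columns."""
    unknown = set(trainable) - set(COLUMN_NAMES)
    if unknown:
        raise ValueError(f"Unknown column names: {sorted(unknown)}")
    row = jnp.array([1.0 if name in trainable else 0.0 for name in COLUMN_NAMES])
    return jnp.tile(row, (input_dim, 1))

# Hypothetical name for this sketch, to avoid clobbering mask_a from the cell above.
named_mask_a = make_column_mask({"rabi_i_amp", "rabi_r_amp"}, input_dim=10)
print(named_mask_a[0])
```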

### Run the Optimization
We run `num_iters` gradient-descent steps using our masked single-step function.

In [None]:
num_iters = 1000
lr = 0.02

ctrl_a = ctrl_a_init
ctrl_b = ctrl_b_init
ctrl_c = ctrl_c_init
cost_history = []

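# Warm-up call to absorb any one-time JAX compilation cost. Its outputs are
# discarded so that cost_history starts from the initial controls.
_ = single_optimization_step_masked(
    prop_jax, ctrl_a, ctrl_b, ctrl_c, mask_a, mask_b, mask_c, lr=lr
)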

for step in range(num_iters):
    cost_val, ctrl_a, ctrl_b, ctrl_c = single_optimization_step_masked(
        prop_jax, ctrl_a, ctrl_b, ctrl_c, mask_a, mask_b, mask_c, lr=lr
    )
    cost_history.append(float(cost_val))

final_cost = cost_history[-1]
print(f"Final cost after {num_iters} steps: {final_cost:.6e}")
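
The masked step is a pure function of its array arguments, so in principle it can be wrapped in `jax.jit` to cut per-iteration overhead. The sketch below shows the pattern on a toy loss. `toy_loss` is a hypothetical stand-in for `propagator.target`, and whether the real `target` method can be traced under `jax.jit` in this form is an assumption we have not checked.

```python
import jax
import jax.numpy as jnp

def toy_loss(a, b, c):
    # Hypothetical stand-in for propagator.target(a, b, c).
    return jnp.sum((a - 1.0) ** 2) + jnp.sum(b ** 2) + jnp.sum(c ** 2)

@jax.jit
def jitted_masked_step(a, b, c, mask_a, mask_b, mask_c, lr):
    cost, (ga, gb, gc) = jax.value_and_grad(toy_loss, argnums=(0, 1, 2))(a, b, c)
    # Masked gradient descent: frozen entries receive a zero update.
    return cost, a - lr * ga * mask_a, b - lr * gb * mask_b, c - lr * gc * mask_c

# Hypothetical names for this sketch (toy_*), so that running it in the notebook
# does not overwrite ctrl_a, ctrl_b, ctrl_c or the masks from the cells above.
shape = (10, 5)
toy_a, toy_b, toy_c = jnp.zeros(shape), jnp.zeros(shape), jnp.zeros(shape)
toy_mask_a = jnp.tile(jnp.array([0.0, 1.0, 0.0, 1.0, 0.0]), (10, 1))
toy_mask_b = jnp.zeros(shape)
toy_mask_c = jnp.zeros(shape)

costs = []
for _ in range(100):
    cost, toy_a, toy_b, toy_c = jitted_masked_step(
        toy_a, toy_b, toy_c, toy_mask_a, toy_mask_b, toy_mask_c, 0.02
    )
    costs.append(float(cost))

print(f"first cost {costs[0]:.3f}, last cost {costs[-1]:.3f}")
```

With a jitted step, the first call pays the compilation cost and later calls reuse the compiled function, which is the situation a warm-up call is meant for.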

## Analyze the Results
Let's see if the amplitude columns changed while the others stayed the same.

In [None]:
plt.figure()
plt.plot(cost_history)
plt.title("Cost vs iteration (partial optimization)")
plt.xlabel("Iteration")
plt.ylabel("Cost")
plt.show()

print("Initial ctrl_a[0]:")
print(ctrl_a_init[0])
print("Final ctrl_a[0]:")
print(ctrl_a[0])
print("(Columns 0,2,4 should be identical to their initial values; columns 1,3 should have changed)")
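
Rather than comparing rows by eye, we can measure the largest change in each column. The sketch below uses toy arrays in place of `ctrl_a_init` and the optimized `ctrl_a`; `max_change_per_column` is a hypothetical helper introduced only for this check.

```python
import jax.numpy as jnp

def max_change_per_column(initial, final):
    """Largest absolute change in each column between two control arrays."""
    return jnp.max(jnp.abs(final - initial), axis=0)

# Toy stand-ins for ctrl_a_init and the optimized ctrl_a.
initial = jnp.zeros((3, 5))
final = initial.at[:, 1].add(0.4).at[:, 3].add(-0.25)

change = max_change_per_column(initial, final)
frozen = jnp.array([0, 2, 4])

print("per-column max change:", change)
print("frozen columns untouched:", bool(jnp.all(change[frozen] == 0.0)))
```

In the notebook, the same call would be `max_change_per_column(ctrl_a_init, ctrl_a)`. Since the masked gradient is exactly zero in the frozen columns, we expect exactly zero change there, not merely a small one.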

### Compare the Final Pulses
We retrieve the final waveforms and check that the detuning and phases match what we started with, while the amplitude waveforms have changed.

In [None]:
final_pulses = prop_jax.return_physical_amplitudes(ctrl_a, ctrl_b, ctrl_c)
final_pulses_np = np.array(final_pulses)
print("Final pulses shape:", final_pulses_np.shape)

# We'll plot the columns
labels = ["Delta_i(t)", "Rabi_i mag", "Rabi_i phase", "Rabi_r mag", "Rabi_r phase"]
tlist = np.linspace(0, prop_jax.duration, final_pulses_np.shape[0])

fig, ax = getStylishFigureAxes(1,1)
for i in range(5):
    PlotPlotter(
        fig,
        ax,
        tlist*1e9,
        final_pulses_np[:, i],
        style={'label': labels[i], 'marker': '', 'linestyle': '-'}
    ).draw()

ax.legend(fontsize=6)
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Dimensionless amplitude")
plt.show()

print("Notice that columns 0,2,4 remain roughly the same shape/time-dependence as initially set.")

## Conclusion

We've shown how to apply a mask to the gradients so that only a chosen subset of control parameters (e.g. the amplitude columns for \(\Omega_i\) and \(\Omega_r\)) is updated during the optimization, while the others stay fixed at their initial values. This helps when we only want to optimize certain channels or certain aspects of the pulse (like amplitude, but not phase).