In [1]:
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import qutip as qt
from collections import defaultdict

import torch
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from torch import nn

import qcontrol
device = qcontrol.device
print("Using device", device)

# Set seed for repeatability: torch's global RNG (used e.g. when the network
# weights are initialised below) and the environment's own seed.
env = qcontrol.QuantumEnv()
SEED = 0
torch.manual_seed(SEED)
env.set_seed(SEED);

Using device cuda:0
Using device cuda:0


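## Define the Problem

Single-qubit control: both computational basis states (the columns of `psi_0`) evolve under $H(t) = H_0 + u(t)\,H_1$ with $H_0 = \sigma_z$ and $H_1 = \sigma_x$, and the agent chooses the control amplitude $u(t)$ aiming for the target `psi_f` (the X / NOT gate).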

In [2]:
# Simulation / reward settings handed to env.gen_params in the training cell.
dt = 0.005
gate_time = 10
control_thresh = 1
control_penalty = 2
close_thresh = 0.99999

np_ctype = np.complex64

# Initial states: the two basis states side by side as columns of the identity.
psi_0 = np.eye(2, dtype=np_ctype)
# Target: the X (NOT) gate, which swaps the two basis states.
psi_f = np.array([[0, 1], [1, 0]], dtype=np_ctype)
# Drift Hamiltonian (Pauli Z).
H0 = np.diag([1, -1]).astype(np_ctype)
# Control Hamiltonian (Pauli X), multiplied by the agent's control amplitude.
H1 = np.array([[0, 1], [1, 0]], dtype=np_ctype)


## Define the Network

In [3]:

class agentNet(nn.Module):
    """Policy network mapping a pair of state tensors to one control value per sample.

    Both inputs are flattened from dim 1 onward and concatenated, so the first
    layer sees ``n_inputs`` = (features of x1) + (features of x2) values per
    sample. The final Linear layer has no activation, so the output is an
    unbounded scalar per sample.

    Args:
        n_inputs: width of the concatenated input. ``None`` (default) means
            ``2 * psi_0.size`` (real and imaginary parts of the notebook-level
            ``psi_0``), looked up when the network is built.
        p_drop: dropout probability applied after the 2nd and 3rd hidden layers.
        bias: whether the Linear layers carry a bias term.
        dtype: parameter dtype of the Linear layers.
    """

    def __init__(self, n_inputs=None, p_drop=0.01, bias=True, dtype=torch.float32):
        super().__init__()
        if n_inputs is None:
            # Real + imaginary part of every entry of the initial state matrix.
            n_inputs = 2 * psi_0.size
        # Layer order is unchanged, so a seeded run initialises identical weights.
        self.net = nn.Sequential(
            nn.Linear(n_inputs, 64, dtype=dtype, bias=bias),
            nn.Tanh(),
            nn.Linear(64, 64, dtype=dtype, bias=bias),
            nn.Tanh(),
            nn.Dropout(p=p_drop),
            nn.Linear(64, 16, dtype=dtype, bias=bias),
            nn.Tanh(),
            nn.Dropout(p=p_drop),
            nn.Linear(16, 1, dtype=dtype, bias=bias),
        )

    def forward(self, x1, x2):
        """Return the network output, shape (batch, 1).

        x1, x2: tensors with a leading batch dimension (e.g. the real and
        imaginary parts of the state); every dimension after the first is
        flattened before concatenation.
        """
        flattened = torch.hstack((x1.flatten(1, -1), x2.flatten(1, -1)))
        return self.net(flattened)

agent = agentNet().to(device)
# Read the state's real/imaginary parts from the tensordict and write the
# network output back under the "action" key.
policy = TensorDictModule(
    agent,
    in_keys=["psi_real", "psi_imag"],
    out_keys=["action"],
)

In [4]:
 optim = torch.optim.Adam(policy.parameters(), lr=2e-3)

# Train Parameters
batch_size = 100
rollout_len = 1500
total_trials = 10000
pbar = tqdm.tqdm(range(total_trials // batch_size))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, total_trials)
logs = defaultdict(list)


params = env.gen_params(batch_size=[batch_size], dt=dt,
                        gate_time=gate_time, control_thresh=control_thresh,
                        control_penalty=control_penalty, psi_0=psi_0,
                        psi_f=psi_f, H0=H0, H1=H1, close_thresh=close_thresh)

for _ in pbar:
    init_td = env.reset(params)
    rollout = env.rollout(rollout_len, policy, tensordict=init_td, auto_reset=False)
    traj_return = rollout["next", "reward"].mean()
    (-traj_return).backward()
    gn = torch.nn.utils.clip_grad_norm_(agent.parameters(), 1.0)
    optim.step()
    optim.zero_grad()
    pbar.set_description(
        f"reward: {traj_return: 4.4f}, "
        f"last reward: {rollout[..., -1]['next', 'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
    )
    logs["return"].append(traj_return.item())
    logs["last_reward"].append(rollout[..., -1]["next", "reward"].mean().item())
    scheduler.step()

  0%|          | 0/100 [00:00<?, ?it/s]

reward: -2.8557, last reward: -1.8894, gradient norm:  4.995:   4%|▍         | 4/100 [00:11<04:38,  2.90s/it]

In [None]:
fig, (ax_ret, ax_last) = plt.subplots(1, 2, figsize=(10, 5))
# Left: mean reward per iteration; right: mean reward at the final rollout step.
ax_ret.plot(logs["return"])
ax_ret.set_title("returns")
ax_ret.set_xlabel("iteration")
ax_last.plot(logs["last_reward"])
ax_last.set_title("last reward")
ax_last.set_xlabel("iteration")

In [None]:
# Reuse the time step (`dt`, problem cell) and horizon (`rollout_len`, training
# cell) the policy was trained with instead of re-declaring them here: a silent
# mismatch would evaluate the agent under different dynamics than it learned.
timesteps = rollout_len
tlist = np.arange(timesteps) * dt

def get_control_pulse(net, timesteps=timesteps, dt=dt, ctype=torch.complex64):
    """Run `net` as a closed-loop controller from `psi_0` and record its pulse.

    Each step feeds the current state's real and imaginary parts to `net`,
    uses its scalar output `u` as the control amplitude, and propagates the
    state with U = exp(-i * dt * (H0 + u * H1)).

    Args:
        net: policy network taking (real, imag) state tensors, e.g. `agent`.
        timesteps: number of piecewise-constant control steps to simulate.
        dt: duration of each control step.
        ctype: complex dtype used for the state and Hamiltonians.

    Returns:
        (pulse, psi_list): `pulse` is a numpy array of length `timesteps`
        (complex dtype matching `ctype`, zero imaginary part); `psi_list[i]` is
        the state *after* step i, i.e. at time (i + 1) * dt.
    """
    was_training = net.training
    net.eval()  # disable dropout so the pulse is deterministic
    psi = torch.tensor(psi_0, dtype=ctype, device=device)
    H0_tensor = torch.tensor(H0, dtype=ctype, device=device)
    H1_tensor = torch.tensor(H1, dtype=ctype, device=device)
    cntrl_seq = torch.zeros(timesteps, dtype=ctype)
    psi_list = []
    try:
        # Inference only: without no_grad every step would extend a single
        # autograd graph through all `timesteps` matrix exponentials.
        with torch.no_grad():
            for i in range(timesteps):
                # Batch of one; agentNet.forward flattens everything after dim 0.
                cntrl = net(torch.real(psi).reshape(1, -1), torch.imag(psi).reshape(1, -1))
                cntrl_seq[i] = cntrl.item()
                U = torch.linalg.matrix_exp(-1.0j * dt * (H0_tensor + cntrl * H1_tensor))
                psi = U @ psi
                psi_list.append(psi.cpu().numpy())
    finally:
        # Restore the caller's mode so e.g. dropout stays active if training resumes.
        net.train(was_training)

    return cntrl_seq.numpy(), psi_list

# Evaluate the trained policy network (`agent`, built in the network cell).
pulse, psi_list = get_control_pulse(agent)

def Ham(t, tlist):
    """Return H0 + u * H1 as a QuTiP operator for time `t`.

    `u` is the sample of the notebook-level `pulse` whose time in `tlist` is
    closest to `t`, i.e. the pulse is treated as piecewise constant.
    """
    nearest = np.argmin(np.abs(tlist - t))
    return qt.Qobj(H0 + pulse[nearest] * H1)

states_qt = qt.sesolve(lambda t, args: Ham(t, tlist), qt.Qobj(psi_0), tlist=tlist)
# Cross-check with QuTiP, reading element (0, 0) through the dense matrix:
# chained `s[0][0][0]` indexing depends on what the installed QuTiP version's
# Qobj.__getitem__ returns (a 2-D row vs a 1-D row) and can fail.
prob0_qt = [np.abs(s.full()[0, 0])**2 for s in states_qt.states]

# psi_list[i] is the state *after* control step i, i.e. at time (i + 1) * dt,
# while pulse[i] starts at tlist[i] = i * dt; shift the state times so each
# population is plotted at the instant it describes.
t_states = tlist + dt
prob0_start0 = [np.abs(s[0, 0])**2 for s in psi_list]
prob0_start1 = [np.abs(s[0, 1])**2 for s in psi_list]

fig, ax = plt.subplots(ncols=2)
ax[0].set_title("pulse")
ax[1].set_title("population in [1,0] state")
# The pulse array is complex with zero imaginary part; plot its real part
# explicitly instead of letting matplotlib drop the imaginary part with a warning.
ax[0].plot(tlist, np.real(pulse))
ax[0].set_xlabel("time")
# ax[1].plot(tlist, prob0_qt, label="qutip")
ax[1].plot(t_states, prob0_start0, label="starts in [1,0]")
ax[1].plot(t_states, prob0_start1, label="starts in [0,1]")
ax[1].set_xlabel("time")
ax[1].legend()
plt.savefig("pulse.png")

KeyboardInterrupt: 