In [1]:
# import gym
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from IPython.display import clear_output
from matplotlib import pyplot as plt

from timeit import default_timer as timer
from datetime import timedelta

import math
import random

from qutip import Bloch, Bloch3d, basis
from qutip.qip.operations import rz

In [19]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
def plot(axs, frame_idx, rewards, losses, sigma, contextual_bias, mu, elapsed_time):
    axs[0].set_title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
    axs[0].plot(rewards)
    if losses and sigma:
        axs[1].set_title('fidelity')
        axs[1].plot(losses, 'g')
        axs[1].plot(sigma, 'r')
    if contextual_bias:
        axs[2].set_title('contextual_bias')
        axs[2].plot(contextual_bias)
        axs[2].axhline(y=mu, color='g', linestyle='-')

In [4]:
def bloch_viz(b, error_samples, error_draw, contextual_bias, correction):  
    b.clear()
    error_samples = np.random.choice(error_samples, 100)
    kets = [rz(error_sample) * (basis(2, 0) + basis(2, 1)).unit() for error_sample in error_samples]
    b.add_states(kets, "point")
    b.add_states([
        rz(error_draw) * (basis(2, 0) + basis(2, 1)).unit(),
        rz(-contextual_bias) * (basis(2, 0) + basis(2, 1)).unit(),
        rz(-correction) * (basis(2, 0) + basis(2, 1)).unit()
        ])
    b.point_color = ["r"]
    b.show()

In [53]:
# (action, reward) distribution for a given "spectator measurement outcome" state
class Context:
    def __init__(self, gamma, eta):
        # discount factor
        self.gamma = gamma
        self.eta = eta

        self.batch_correction_feedback = [[], [], []]
        self.batch_context_feedback = [[], [], []]
        self.context_theta = [0, 0, 0]
        self.correction_theta = [0, 0, 0]

    def reset(self):
        self.batch_correction_feedback = []
        self.batch_context_feedback = []

    def discount(self):
        pass

    def update_gamma(self, gamma):
        self.gamma = gamma

    def update_batch_feedback(self, context_feedback, correction_feedback):
        # feedback is given per theta
        for idx, f in enumerate(context_feedback):
            self.batch_context_feedback[idx].append(f)
        
        for idx, f in enumerate(correction_feedback):
            self.batch_correction_feedback[idx].append(f)

    def get_and_combine_optimal_theta(self):
        if len(self.batch_correction_feedback) == 0:
            return self.correction_theta, self.context_theta

        # feedback is given per theta
        for idx, f in enumerate(self.batch_correction_feedback):
            lo = [r[0] for r in f]
            mid = [r[1] for r in f]
            hi = [r[2] for r in f]
            
            mu_plus = np.mean(hi)
            mu = np.mean(mid)
            mu_minus = np.mean(lo)

            var_grad = np.mean([2 * (m - mu)  * (h - l) for l,m,h in zip(lo,mid,hi)])
            grad = mu_plus - mu_minus

            self.correction_theta[idx] += self.eta * grad
        
        for idx, f in enumerate(self.batch_context_feedback):
            lo = [r[0] for r in f]
            mid = [r[1] for r in f]
            hi = [r[2] for r in f]
            
            mu_plus = np.mean(hi)
            mu = np.mean(mid)
            mu_minus = np.mean(lo)

            grad = mu_plus - mu_minus

            self.context_theta[idx] += self.eta * var_grad

        self.reset()
        
        return self.correction_theta, self.context_theta

    def get_optimal_theta(self):
        return self.correction_theta

In [54]:
# contextual analytic geometric descent
class Analytic2D:
    def __init__(self, env, initial_gamma=0.99, eta=np.pi/64):
        # two contexts -> pos vs neg rotation with respect to a chosen rotational basis
        self.contexts = [Context(initial_gamma, eta), Context(initial_gamma, eta)]

        self.rewards = []
        self.fidelity = []
        self.control_fidelity = []

        self.num_context_spectators = env.num_context_spectators
        self.num_reward_spectators = env.num_reward_spectators

        # step size
        self.eta = eta

    def get_actions(self, observations):
        # our context is an array of binary spectator qubit measurements
        # hence, we could convert this binary array to an integer and index 2^(spectator qubits) contexts
        # context = self.contexts[np.packbits(observation, bitorder='little')[0]] 
        # for now, we only have two contexts (+ vs -), and so we consider spectators to be indistinguishable noise polling devices
        # in the future, we may consider noise gradients and so we do indeed need to track the specific arrangement

        actions = []
        for observation in observations:
            context_idx = 1 if np.sum(observation) > self.num_context_spectators / 2 else 0
            context = self.contexts[context_idx]

            # (theta vec)
            actions.append(context.get_and_combine_optimal_theta())
        return actions

    def update(self, context_feedback, correction_feedback, observations):
        for context in self.contexts:
            context.reset()

        for _context_feedback, _correction_feedback, observation in zip(context_feedback, correction_feedback, observations):
            context_idx = 1 if np.sum(observation) > self.num_context_spectators / 2 else 0
            context = self.contexts[context_idx]
            context.update_batch_feedback(_context_feedback, _correction_feedback)

    def _update_gammas(self, gamma):
        for context in self.contexts:
            context.update_gamma(gamma)

    def save_reward(self, reward):
        self.rewards.append(reward)

    def save_fidelity(self, fidelity):
        self.fidelity.append(fidelity)

    def save_control_fidelity(self, fidelity):
        self.control_fidelity.append(fidelity)

In [55]:
M = 1000
BATCH_SIZE = 100

MU = np.pi / 16
SIGMA = np.pi / 8

# add some random harmonics
# time_dependent_fn = np.vectorize(lambda x: (2/5) * np.pi * np.sin(2 * np.pi * x / M))
error_samples = np.random.normal(MU, SIGMA, M) # + time_dependent_fn(np.arange(M))
# error_samples = np.random.choice([-np.pi/4, np.pi / 4], M)

In [56]:
from spectator_env import SpectatorEnvContinuousV2

# describes MDP
# - states are given in terms of a #context_spectators bit measurement outcomes
# - continuous action space is given by ([-pi, pi], smoothing parameter, contextual_measurement_bias) i.e. \U(1) \times U(1) \times U(1)
env = SpectatorEnvContinuousV2(error_samples, batch_size=BATCH_SIZE, num_context_spectators=64, num_reward_spectators=64, context_sensitivity=2.0, reward_sensitivity=2.0)


In [57]:
start = timer()

md = Analytic2D(env, initial_gamma=1.0, eta=np.pi/8)

episode_reward = 0
episode_fidelity = []
control_fidelity = []

MAX_FRAMES = 10 * M // BATCH_SIZE
UPDATE_ERROR_SAMPLES_FRAMES = M // BATCH_SIZE

PLOT_BLOCH = False
PLOT_REWARD = True

# "episodes" are a reasonable way to think about learning a periodic time dependent function
# max frames = episode length * num episodes
# episodes are identical sequences of training data
observation = env.reset()

for frame_idx in range(1, MAX_FRAMES + 1):
    actions = md.get_actions(observation)
    prev_batch = env.error_samples_batch
    prev_observation=observation
    observation, feedback, done, info = env.step(actions)
    observation = None if done else observation

    # generally, in RL we would consider r(s | s', a)
    # given a state transition s' -> s due to action a
    # for now, we are only interested in r(s', a)
    md.update(actions, reward, prev_observation)
    episode_reward += np.sum(reward)

    for x, y in info:
        episode_fidelity.append(x)
        control_fidelity.append(y)

    if UPDATE_ERROR_SAMPLES_FRAMES:
        new_error_samples = np.random.normal(MU, SIGMA, M)
        env.set_error_samples(new_error_samples)

    if done:
        if PLOT_BLOCH or PLOT_REWARD:
            clear_output(True)
        if PLOT_BLOCH:
            bloch_viz(b, env.error_samples, prev_batch[0], md.action_contextual_bias, actions[0][0])
        if PLOT_REWARD and len(md.rewards) > 0 and len(md.fidelity) > 0 and len(md.contextual_biases) > 0:
            fig, axs = plt.subplots(1, 3, figsize=(20,5))
            plot(axs, frame_idx, md.rewards, md.fidelity, md.control_fidelity, md.contextual_biases, -MU,
                timedelta(seconds=int(timer()-start)))
            plt.show()

        observation = env.reset()
        md.save_reward(episode_reward)
        md.save_fidelity(np.mean(episode_fidelity))
        md.save_control_fidelity(np.mean(control_fidelity))
        episode_reward = 0
        episode_fidelity = []
        control_fidelity = []

env.close()



  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


ValueError: cannot convert float NaN to integer

In [10]:
print(np.mean(md.contextual_biases))

-0.13928068662033852


In [11]:
print(md.contexts[1].get_optimal_theta())

[-0.54754075 -0.36648637 -0.71474008]


In [12]:
print(md.action_contextual_bias, MU)

[-0.21649348 -0.17622204 -0.19074786] 0.19634954084936207
