# Interactive Visualization of Attention Scores
---

In [None]:
# Reload edited modules (e.g. the local `demo_helpers`) automatically before each
# cell executes, so helper changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Interactive matplotlib backend, presumably so the attention figure built below
# can be updated in place as the controls change the state.
# NOTE(review): the `notebook` (nbagg) backend does not work in JupyterLab or
# Notebook 7; `%matplotlib widget` (ipympl) is the modern replacement — confirm
# ipympl is installed before switching.
%matplotlib notebook

In [None]:
import numpy as np
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
import ipywidgets as widgets
from functools import partial
from IPython.display import display

from demo_helpers import (
    SingleQubitRotationControls,
    TwoQubitsRotationControls,
    StepControls,
    QuantumState,
    InteractiveAttentionScores,
    get_special_states,
    apply_1q_rotation,
    apply_2q_rotation
)

Global configuration constants

In [None]:
# Seed shared by the numpy and torch random number generators.
SEED: int = 4
# Device the agent checkpoint is mapped onto when it is loaded.
DEVICE: str = 'cpu'
# Number of qubits; passed to the attention-score helper.
N_QUBITS: int = 4
# NOTE(review): not referenced in the cells visible here — confirm it is still needed.
STEPS_LIMIT: int = 40
# Trained-agent checkpoint. The directory name presumably records the training
# hyperparameters (4 qubits, 2 attention heads, 2 transformer layers, ...).
AGENT_PATH: str = (
    'logs/4q_pGen_0.9_attnHeads_2_tLayers_2'
    '_ppoBatch_512_entReg_0.1_embed_128_mlp_256/agent.pt'
)

Set random seeds and print options

In [None]:
# Print arrays/tensors with 5 digits of precision and without scientific notation.
torch.set_printoptions(precision=5, sci_mode=False)
np.set_printoptions(precision=5, suppress=True)

# Seed both RNGs and request deterministic cuDNN kernels so runs are repeatable.
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True

Load the trained agent

In [None]:
# Load the trained agent and map its tensors onto DEVICE.
# The checkpoint is a whole pickled agent object (it is passed as an object to the
# attention/step helpers later in this notebook), not just a state_dict.
# PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle such
# objects, so opt out explicitly (the argument exists since torch 1.13).
# SECURITY: unpickling can execute arbitrary code — only load trusted checkpoints.
agent = torch.load(AGENT_PATH, map_location=torch.device(DEVICE), weights_only=False)

Initialize the quantum state

In [None]:
# List the names of the predefined special states for this qubit count.
# NOTE(review): get_special_states() appears to be keyed by qubit count (index 4
# was previously hard-coded while N_QUBITS == 4) — confirm.
for state_name in get_special_states()[N_QUBITS]:
    print(state_name)

In [None]:
# Start from the named special state "|RRRR>" and display it.
# NOTE(review): the label is spelled for 4 qubits; if N_QUBITS changes, pick one of
# the names printed by the previous cell instead.
state = get_special_states()[N_QUBITS]["|RRRR>"]
state

In [None]:
# Initialize Attention plot
attn_scores = InteractiveAttentionScores(N_QUBITS, agent)
# Scores are unpacked as a 4-D array whose first two axes are (layers, heads).
n_layers, n_heads, _, _ = attn_scores.get_attention_scores(state).shape

# Number of attention panels in the top row.
# NOTE(review): n_layers / n_heads above are not used; this stays hard-wired to 4,
# which equals n_layers * n_heads for a 2-layer / 2-head agent and also N_QUBITS —
# confirm which quantity set_attn_axes expects before deriving it.
n_attn_cols = 4

# 2-row grid: the top row holds one axis per attention panel; the bottom-row
# placeholders are removed and replaced by a single full-width axis.
fig, axs = plt.subplots(
    nrows=2,
    ncols=n_attn_cols,
    figsize=(15, 5),
    tight_layout={"pad": 2},
    dpi=60,
)
for ax in axs[1, :]:
    ax.remove()

attn_axs = axs[0, :]
prob_ax = plt.subplot2grid(
    shape=(2, n_attn_cols), loc=(1, 0), colspan=n_attn_cols, fig=fig
)
attn_scores.set_attn_axes(attn_axs)
attn_scores.set_prob_ax(prob_ax)

# Initialize controls.
# qstate is registered with attn_scores.update as its handler, presumably so every
# rotation/step applied to it refreshes the attention figure.
qstate = QuantumState(state, handler=attn_scores.update)
q1_rotate_func = partial(apply_1q_rotation, qstate)
q2_rotate_func = partial(apply_2q_rotation, qstate)
# Size the rotation controls by the configured qubit count instead of a literal 4.
q1_rotation_controls = SingleQubitRotationControls(N_QUBITS, q1_rotate_func)
q2_rotation_controls = TwoQubitsRotationControls(N_QUBITS, q2_rotate_func)
step_controls = StepControls(agent, qstate, [q1_rotation_controls, q2_rotation_controls])

# Display controls
q1_rotation_controls.display()
q2_rotation_controls.display()
step_controls.display()
qstate.display()