In [416]:
import sys
sys.path.append('..')
import numpy as np
import jax.numpy as jnp
from jax import vmap
import benpy as bp
%gui qt
import pyqtgraph as pg
from importlib import reload  

bp = reload(bp)
from pyqtgraph.Qt import QtCore

In [442]:
# Baseline parameter set passed to `load` as keyword arguments: for each of
# the a/d/p/k groups, two diagonal entries and one rotation angle.
params = dict(
    a1=-1., a2=-.5, theta_a=0.,
    d1=-1., d2=-.5, theta_d=0.,
    p1=-1, p2=1., theta_p=0.,
    k1=-1., k2=1., theta_k=0.,
)

    
def load(a1,a2,theta_a, d1,d2,theta_d,
         p1,p2,theta_p, k1,k2,theta_k,**kwargs):
    """Assemble the 2x2 blocks of the system matrix from rotated diagonals.

    For each group (a, d, p, k) a matrix ``R @ diag(x1, x2) @ R.T`` is built,
    where ``R = bp.rotation2(theta, np=jnp)`` (presumably a 2x2 rotation
    matrix -- confirm against benpy).

    Parameters
    ----------
    a1, a2, theta_a : diagonal entries and angle used for ``A``.
    d1, d2, theta_d : diagonal entries and angle used for ``D``.
    p1, p2, theta_p : diagonal entries and angle used for ``P``.
    k1, k2, theta_k : diagonal entries and angle used for ``K``.
    **kwargs : accepted and ignored, so a larger parameter dict can be
        splatted into this function.

    Returns
    -------
    tuple
        ``((A, P - K), (P.T + K.T, D))`` -- a nested 2x2 arrangement of blocks.
    """
    def rotated_diagonal(first, second, theta):
        rotation = bp.rotation2(theta, np=jnp)
        # Built with jnp rather than host NumPy so that the diagonal entries,
        # not only the angles, may be JAX tracers (e.g. under vmap/jit/grad).
        scaling = jnp.array([[first, 0], [0, second]])
        return rotation @ scaling @ rotation.T

    A = rotated_diagonal(a1, a2, theta_a)
    D = rotated_diagonal(d1, d2, theta_d)
    P = rotated_diagonal(p1, p2, theta_p)
    K = rotated_diagonal(k1, k2, theta_k)

    return ((A, P - K), (P.T + K.T, D))
    
def isStable(M):
    """Return whether every eigenvalue of ``M`` has a strictly negative real part.

    Parameters
    ----------
    M : square array accepted by ``jnp.linalg.eigvals``.

    Returns
    -------
    A scalar boolean JAX array.
    """
    eigenvalues = jnp.linalg.eigvals(M)
    # Take the real part *before* comparing. The previous form compared the
    # complex eigenvalues themselves with 0 and applied `real` to the boolean
    # result, so the imaginary part could influence the verdict.
    return jnp.all(jnp.real(eigenvalues) < 0)

def run(params):
    """Evaluate one parameter set.

    ``params`` is splatted into `load`; the resulting blocks are joined with
    ``bp.block`` and also handed to ``bp.numrange`` (presumably a sampler of
    the numerical range -- confirm against benpy).

    Returns ``(eigs, qnum)``: the eigenvalues of the joined matrix and
    whatever ``bp.numrange(blocks, num=1e2)`` returns.
    """
    blocks = load(**params)
    joined = bp.block(blocks, np=jnp)
    spectrum = np.linalg.eigvals(joined)
    return spectrum, bp.numrange(blocks, num=1e2)

In [446]:
# Brush/pen colours for the plot items created in this cell.
color1 = "#F5E800"
color2 = "#1ACDEA"
color3 = "#E1851A"
color4 = "#8C0053"

# Close the window left over from a previous run of this cell. On a fresh
# kernel `win` is not defined yet, so look it up instead of referencing it
# directly (which raised NameError and broke Restart & Run All).
if globals().get('win') is not None:
    win.close()

pg.setConfigOptions(background=(255,255,255))
win = pg.GraphicsLayoutWidget()
win.resize(800,400)
win.show()

# Text label that later shows the assembled matrix, followed by the main plot
# in the complex plane.
l_matrix = win.addLabel('matrix')
p_eigs = win.addPlot()

# Scatter-only items (no connecting line): small dots, crosses, and two sets
# of larger dots in distinct colours.
plt_qnum = pg.PlotDataItem(
    pen=None, symbol='o', symbolSize=2,
    symbolPen=None, symbolBrush=color1)
plt_eigs = pg.PlotDataItem(
    pen=None, symbol='x', symbolSize=10,
    symbolPen=color2, symbolBrush=color2)
plt_diags1 = pg.PlotDataItem(
    pen=None, symbol='o', symbolSize=6,
    symbolPen=color3, symbolBrush=color3)
plt_diags2 = pg.PlotDataItem(
    pen=None, symbol='o', symbolSize=6,
    symbolPen=color4, symbolBrush=color4)

# Same insertion order as before: qnum, eigs, diags1, diags2.
for item in (plt_qnum, plt_eigs, plt_diags1, plt_diags2):
    p_eigs.addItem(item)

p_eigs.setLabels(bottom='Real', left='Imaginary')

# Six panels laid out two per row, starting at layout row 1. Each panel holds
# an image item and a square-marker data item; the three lists are
# index-aligned.
p_rot = []
plt_imv = []
plt_pair = []

# A named index is used instead of `_`: assigning `_` at the top level of a
# cell overrides IPython's "last output" variable.
for panel_index in range(6):
    layout_row, layout_col = divmod(panel_index, 2)
    panel = win.addPlot(row=layout_row + 1, col=layout_col)
    image = pg.ImageItem(scale=(0.1,0.1))
    marker = pg.PlotDataItem(symbolSize=5, symbol='s')
    image.show()
    panel.addItem(image)
    panel.addItem(marker)
    p_rot.append(panel)
    plt_imv.append(image)
    plt_pair.append(marker)

In [447]:
# Parameter set for this run, passed to `load`/`run` as keyword arguments.
params = dict(
    a1=-1., a2=.05, theta_a=2.,
    d1=-2., d2=-1., theta_d=3.4,
    p1=0., p2=0., theta_p=.0,
    k1=-1., k2=1., theta_k=0.,
)

# Limits handed to `bp.grid` and to the image rectangles further down.
xlim = [0, 2*np.pi]
ylim = [0, 2*np.pi]

# Eigenvalues and `bp.numrange` samples for the current parameter set.
eigs, qnum = run(params)

# Symmetric, square view: half-width is the largest |Re| among the qnum points.
# NOTE(review): the same bound is used for the imaginary axis -- confirm that
# points with a larger |Im| are meant to fall outside the view.
real_parts = np.real(qnum)
c = np.max(np.abs(real_parts))
p_eigs.setRange(xRange=(-c, c), yRange=(-c, c))

plt_eigs.setData(np.real(eigs), np.imag(eigs))
plt_qnum.setData(np.real(qnum), np.imag(qnum))
# The diagonal entries a1/a2 and d1/d2 are drawn on the real axis.
plt_diags1.setData([params[key] for key in ('a1', 'a2')], [0, 0])
plt_diags2.setData([params[key] for key in ('d1', 'd2')], [0, 0])

# Rebuild the matrix and show it in the label, one text row per matrix row.
M = load(**params)
(A,B),(C,D) = M
J = bp.block(M)
matrix_rows = [' '.join('{: 2.2f}'.format(m) for m in row) for row in J]
l_matrix.setText("M=<br>" + "<br>".join(matrix_rows))

# Pick up any edits to benpy before building the sampling grid.
bp = reload(bp)
grid, shape = bp.grid(xlim=xlim, ylim=ylim, xnum=64, ynum=64)


def scan(s):
    """Build a sampler that sweeps the parameters named in ``s``.

    The remaining entries of the global ``params`` are captured (as a copy)
    when ``scan`` is called and held fixed.

    Parameters
    ----------
    s : sequence of parameter names present in ``params``.

    Returns
    -------
    sample : callable
        ``sample(thetas, np=jnp)`` pairs ``thetas`` with the names in ``s``,
        rebuilds the matrix via `load` and ``bp.block`` and returns the
        largest real part among its eigenvalues.

    Raises
    ------
    KeyError
        If a name in ``s`` is not a key of ``params``.
    """
    fixed = dict(params)
    for name in s:
        # `del` keeps the KeyError on unknown names without building a
        # throw-away list as a side-effect comprehension would.
        del fixed[name]

    def sample(thetas, np=jnp):
        swept = dict(zip(s, thetas))
        M = load(**swept, **fixed)
        # NOTE(review): the block matrix is always assembled with jnp, even if
        # a different array namespace is passed as `np` -- confirm intent.
        J = bp.block(M, np=jnp)
        return np.max(np.real(np.linalg.eigvals(J)))
    return sample

# Parameter pairs to sweep, one per panel: the first name goes on the bottom
# axis, the second on the left axis.
s = [['theta_a','theta_d'],
    ['theta_p','theta_k'],
    ['theta_a','theta_k'],
    ['theta_d','theta_k'],
    ['theta_a','theta_p'],
    ['theta_d','theta_p']]

# Running extrema over all panels, so that they can share one colour scale.
data_min = np.inf
data_max = -np.inf
# The loop variable is named `pair` so that it no longer overwrites the list
# `s` it iterates over; `s` stays the full list after the loop.
for i, pair in enumerate(s):
    sample = scan(pair)
    data = np.array(vmap(sample)(grid).reshape(shape))
    data_min = np.minimum(np.min(data), data_min)
    data_max = np.maximum(np.max(data), data_max)
    plt_imv[i].setImage(data)
    # QRectF takes (x, y, width, height), so pass the extents rather than the
    # upper limits; the two only coincide when the lower limits are 0.
    plt_imv[i].setRect(QtCore.QRectF(xlim[0], ylim[0],
                                     xlim[1] - xlim[0], ylim[1] - ylim[0]))
    p_rot[i].setLabels(bottom=pair[0], left=pair[1])
    # Mark the current values of the swept pair on the panel.
    plt_pair[i].setData([params[pair[0]]], [params[pair[1]]])

# Apply the shared colour levels to every image.
for p in plt_imv:
    p.setLevels([data_min, data_max])