In [1]:
import jax
import jax.numpy as np
from jax import vmap, random, partial

def rot(t):
    return np.array([[np.cos(t), -np.sin(t)],
                     [np.sin(t), np.cos(t)]])
def diag(d1, d2):
    return np.array([[d1,0],[0,d2]])

def rand(key):
    k = random.split(key, 2)
    theta, = jax.random.uniform(k[0], (1,), minval=0, maxval=np.pi) 
    d1,d2,b1,b2,c,d1,d2 \
        = jax.random.uniform(k[1], (7,), minval=-1, maxval=1)

    A = rot(theta).T @ diag(d1, d2) @ rot(theta)
    b = np.array([[b1], [b2]])
    d = np.array([[d1,d2]])
    c = np.array([[c]])
    M = np.vstack([np.hstack([A,b]),
                   np.hstack([c,d])])
    return M

def isStable(M):
    return np.all(np.linalg.eigvals(M) < 0)
def isNash(M):
    return np.all(np.diag(M) < 0)

key = jax.random.PRNGKey(0)
def sample(key, mode='nonnash_stable', tau=10., round_factor=5):
    M = rand(key)
    Gamma = np.array([[1., 1., tau]]).T
    if 'round' in mode:
        M = np.round(M*round_factor)
    
    stable = isStable(M)
    unstable = np.logical_not(stable)
    nash = isNash(M)
    nonnash = np.logical_not(nash)
    stable_fast = isStable(Gamma*M)
    unstable_fast = np.logical_not(stable_fast)
    
    result = True
    if 'nonnash' in mode:
        result = np.logical_and(result, nonnash)
    elif 'nash' in mode:
        result = np.logical_and(result, nash)
    if 'unstable_slow' in mode:
        result = np.logical_and(result, unstable)
    elif 'stable_slow' in mode:
        result = np.logical_and(result, stable)
    if 'unstable_fast' in mode:
        result = np.logical_and(result, unstable_fast)
    elif 'stable_fast' in mode: 
        result = np.logical_and(result, stable_fast)
    return result, M
    
def go(mode, seed=0, batch_size = 1024*4):
    keys = random.split(key, batch_size)
    out = vmap(partial(sample, mode=mode))(keys)
    print("found {} matrices out of {} that satisfy {}".format(len(out[1][out[0]]), batch_size, mode))
    return out
    
def print_eigs(M):
    A,B,C,D = M[0:2,0:2], M[0:2,2,np.newaxis], M[2:,0:2], M[2,2]
    A,B,C,D,M
    Acl = A+B@C
    Gamma = np.array([[1., 1., 10.]]).T

    print(M)
    def fmt(M):
        eigs = np.linalg.eigvals(M)
        desc = "stable" if np.all(np.real(eigs) < 0) else "unstable"
        return desc+" ({:.2f},{:.2f})".format(*eigs)
    print("eig(A)   = " + fmt(A))
    print("eig(A+BC)= " + fmt(A+B@C))
    print("eig(M)   = " + fmt(M))
    print("eig(GM)  = " + fmt(Gamma*M))




In [2]:
idx, Ms = go('nash,stable_slow,stable_fast,rounded')
print_eigs(Ms[idx][0])

found 627 matrices out of 4096 that satisfy nash,stable_slow,stable_fast,rounded
[[-3.  0.  1.]
 [ 0. -3.  1.]
 [ 0. -3. -3.]]
eig(A)   = stable (-3.00+0.00j,-3.00+0.00j)
eig(A+BC)= stable (-3.00+0.00j,-6.00+0.00j)
eig(M)   = stable (-3.00+0.00j,-3.00+1.73j)
eig(GM)  = stable (-3.00+0.00j,-4.16+0.00j)


In [3]:
idx, Ms = go('nash,stable_slow,unstable_fast,rounded')
print_eigs(Ms[idx][5])

found 9 matrices out of 4096 that satisfy nash,stable_slow,unstable_fast,rounded
[[-4. -1. -2.]
 [-1. -1.  0.]
 [-3. -4. -1.]]
eig(A)   = stable (-4.30+0.00j,-0.70+0.00j)
eig(A+BC)= unstable (0.50+2.18j,0.50-2.18j)
eig(M)   = stable (-5.80+0.00j,-0.10+0.92j)
eig(GM)  = unstable (-15.65+0.00j,0.33+1.76j)


In [4]:
idx, Ms = go('nash,unstable_slow,stable_fast,rounded')
print_eigs(Ms[idx][2])

found 8 matrices out of 4096 that satisfy nash,unstable_slow,stable_fast,rounded
[[-2. -1.  0.]
 [-1. -3. -2.]
 [-3. -4. -2.]]
eig(A)   = stable (-1.38+0.00j,-3.62+0.00j)
eig(A+BC)= unstable (-1.19+0.00j,4.19+0.00j)
eig(M)   = unstable (-5.79+0.00j,-1.21+0.00j)
eig(GM)  = stable (-23.96+0.00j,-1.04+0.00j)


In [5]:
idx, Ms = go('nash,unstable_slow,unstable_fast,round')
print_eigs(Ms[idx][2])

found 338 matrices out of 4096 that satisfy nash,unstable_slow,unstable_fast,round
[[-2.  0. -1.]
 [ 0. -2. -4.]
 [ 2. -2. -2.]]
eig(A)   = stable (-2.00+0.00j,-2.00+0.00j)
eig(A+BC)= unstable (-2.00+0.00j,4.00+0.00j)
eig(M)   = unstable (-4.45+0.00j,-2.00+0.00j)
eig(GM)  = unstable (-22.87+0.00j,-2.00+0.00j)


In [6]:
idx, Ms = go('nonnash,stable_slow,stable_fast,round')
print_eigs(Ms[idx][0])

found 90 matrices out of 4096 that satisfy nonnash,stable_slow,stable_fast,round
[[ 1. -1.  3.]
 [-1. -3. -2.]
 [-2.  1. -4.]]
eig(A)   = unstable (1.24+0.00j,-3.24+0.00j)
eig(A+BC)= stable (-2.55+0.00j,-7.45+0.00j)
eig(M)   = stable (-1.00+0.00j,-2.50+0.87j)
eig(GM)  = stable (-37.94+0.00j,-0.52+0.00j)


In [7]:
idx, Ms = go('nonnash,stable_slow,unstable_fast,round')
print_eigs(Ms[idx][0])

found 13 matrices out of 4096 that satisfy nonnash,stable_slow,unstable_fast,round
[[-1.  2. -1.]
 [ 2. -2.  3.]
 [ 0. -3.  1.]]
eig(A)   = unstable (0.56+0.00j,-3.56+0.00j)
eig(A+BC)= stable (-0.08+0.00j,-11.92+0.00j)
eig(M)   = stable (-1.53+0.00j,-0.24+1.79j)
eig(GM)  = unstable (-0.78+0.00j,3.89+7.00j)


In [8]:
idx, Ms = go('nonnash,unstable_slow,stable_fast,round')
Ms[idx]
print_eigs(Ms[idx][0])

found 81 matrices out of 4096 that satisfy nonnash,unstable_slow,stable_fast,round
[[ 2. -1.  3.]
 [-1. -1.  1.]
 [-4.  2. -1.]]
eig(A)   = unstable (2.30+0.00j,-1.30+0.00j)
eig(A+BC)= stable (-6.79+0.00j,-2.21+0.00j)
eig(M)   = unstable (0.85+2.85j,0.85-2.85j)
eig(GM)  = stable (-3.47+7.79j,-3.47-7.79j)


In [9]:
idx, Ms = go('nonnash,unstable_slow,unstable_fast,round')
print_eigs(Ms[idx][0])

found 2930 matrices out of 4096 that satisfy nonnash,unstable_slow,unstable_fast,round
[[ 1.  2.  0.]
 [ 2.  4.  1.]
 [-4.  5.  0.]]
eig(A)   = unstable (0.00+0.00j,5.00+0.00j)
eig(A+BC)= unstable (1.54+0.00j,8.46+0.00j)
eig(M)   = unstable (-1.80+0.00j,1.32+0.00j)
eig(GM)  = unstable (-6.27+0.00j,2.31+0.00j)
