In [None]:
import sys
sys.path.append("./..")

In [None]:
import numpy as np
import scipy.stats as stats
import torch
import matplotlib.pyplot as plt
#
from misc.plot_utils import plot_couplings, plot_capsules, plot_mat, plot_mat2

# Simulate Routing Scenarios

In [None]:
def get_c_perfect_dynamic(n_l, n_h, n_samples, pr=0.0):
    """
        Simulate perfectly dynamic routing: in every active sample each lower
        capsule is coupled strongly to one randomly drawn higher capsule.

        n_l: number of lower layer capsules
        n_h: number of higher layer capsules
        n_samples: total number of samples
        pr: passive rate; if > 0, int(pr * n_samples) samples are uniform
            couplings, appended AFTER the active samples

        returns: np.ndarray (n_samples, n_l, n_h), last axis sums to 1
    """
    if pr > 0:
        n_samples_p = int(pr * n_samples)
        n_samples = n_samples - n_samples_p

    vals = torch.randint(0, n_h, size=(n_samples, n_l))
    # num_classes must be given explicitly: without it one_hot infers
    # max(vals) + 1, which is smaller than n_h whenever the last higher capsule
    # is never drawn (and fails outright for an empty draw).
    C = torch.nn.functional.one_hot(vals, num_classes=n_h).float() * 10
    C = torch.softmax(C + 1e-7, dim=-1).numpy()

    if pr > 0:
        C2 = get_c_uniform(n_l, n_h, n_samples_p)

        C = np.concatenate([C, C2], axis=0)
    return C

def get_c_uniform(n_l, n_h, n_samples):
    """
        Uniform couplings: softmax over constant logits, so every lower
        capsule spreads its coupling evenly over the n_h higher capsules.

        returns: np.ndarray (n_samples, n_l, n_h)
    """
    logits = torch.ones(n_samples, n_l, n_h)
    uniform = torch.nn.functional.softmax(logits, dim=-1)
    return uniform.numpy()

def get_c_rand(n_l, n_h, n_samples, cs=1, pr=0):
    """
        Random routing: softmax over uniform random logits scaled by cs.

        cs: coupling strength (scale of the random logits)
        pr: passive rate; if > 0, int(pr * n_samples) samples are uniform
            couplings, appended after the random ones

        returns: np.ndarray (n_samples, n_l, n_h)
    """
    n_passive = int(pr * n_samples) if pr > 0 else 0
    n_active = n_samples - n_passive

    logits = torch.rand(n_active, n_l, n_h) * cs
    active = torch.softmax(logits, dim=-1).numpy()

    if not pr > 0:
        return active

    passive = get_c_uniform(n_l, n_h, n_passive)
    return np.concatenate([active, passive], axis=0)

def get_c_rate_strength(n_l, n_h, n_samples, cs=None, cr=None, nc=None, pr=0.0):
    """
        Routing that switches between nc different static coupling patterns.

        pr: passive rate; if > 0, int(pr * n_samples) samples are uniform
            couplings, appended after the active ones
        cs: Coupling Strength, forwarded to get_c_static
            (None: use get_c_static's own default)
        cr: Coupling Rate, nc = int(cr * n_h)
        (nc: number of couplings)
        use either cr or nc

        returns: np.ndarray (n_samples, n_l, n_h)
        raises: ValueError if neither cr nor nc is given,
                AssertionError if the resulting nc is not positive
    """
    if nc is None:
        if cr is None:
            raise ValueError("either cr or nc must be given")
        nc = int(cr * n_h)
    #
    assert nc > 0
    #
    n_samples_p = int(pr * n_samples) if pr > 0 else 0
    n_active = n_samples - n_samples_p

    # Spread the active samples over the nc patterns. The first `extra`
    # patterns get one more sample, so the total is exactly n_active instead of
    # silently dropping the remainder of n_active // nc.
    base, extra = divmod(n_active, nc)
    # Only forward cs when it was given; passing None would break the logit
    # scaling inside get_c_static.
    static_kwargs = {} if cs is None else {"cs": cs}
    CC = [
        get_c_static(n_l, n_h, base + (1 if i < extra else 0), **static_kwargs)
        for i in range(nc)
    ]

    if pr > 0:
        CC.append(get_c_uniform(n_l, n_h, n_samples_p))

    return np.concatenate(CC, axis=0)

def get_c_static(n_l, n_h, n_samples, cs=10, pr=0):
    """
        Static routing: one random coupling pattern (each lower capsule
        prefers one higher capsule) repeated for every active sample.

        cs: coupling strength (logit of the preferred higher capsule)
        pr: passive rate; if > 0, int(pr * n_samples) samples are uniform
            couplings, appended after the static ones

        returns: np.ndarray (n_samples, n_l, n_h)
    """
    n_passive = int(pr * n_samples) if pr > 0 else 0
    n_active = n_samples - n_passive

    # One preferred parent per lower capsule, shared by all samples.
    targets = torch.randint(0, n_h, size=(n_l,))
    logits = torch.nn.functional.one_hot(targets, num_classes=n_h).float() * cs
    pattern = torch.softmax(logits + 1e-7, dim=-1)
    active = pattern.unsqueeze(0).repeat(n_active, 1, 1).numpy()

    if pr > 0:
        passive = get_c_uniform(n_l, n_h, n_passive)
        return np.concatenate([active, passive], axis=0)

    return active

def calc_norm_entropy(C):
    """
        Normalized entropy of the sample-averaged couplings per lower capsule.

        C (b, n_l, n_h)
        returns: (n_l,) entropy of C.mean(axis=0) along n_h, divided by
                 log(n_h): 0 = all coupling on one higher capsule, 1 = uniform
    """
    Cm = C.mean(axis=0)
    # Use the convention 0 * log(0) = 0; evaluating it directly gives NaN and
    # would poison the whole row sum.
    positive = Cm > 0
    plogp = np.zeros_like(Cm)
    plogp[positive] = Cm[positive] * np.log(Cm[positive])
    Ce = -np.sum(plogp, axis=1) / np.log(Cm.shape[1])
    return Ce

## Mean Max & Adjusted Mean Max

In [None]:
def mean_max(C):
    """
        Strongest coupling of each lower layer capsule, averaged over samples.
        High values indicate strong couplings.
        Only interpretable for ACTIVE capsules!
        C (b, n_l, n_h) c in C in [0, 1]
        returns: (n_l,)
    """
    strongest = np.max(C, axis=2)  # (b, n_l)
    return strongest.mean(axis=0)

def adjusted_mean_max(C, pr):
    """
        mean_max(C), adjusted for a rate pr in [0, 1) of passive samples whose
        couplings are assumed uniform (max = 1 / n_h).

        C (b, n_l, n_h)
        pr: either a scalar (python or numpy, int or float)
            or a vector with one passive rate per lower capsule
        returns: (n_l,)
        raises: ValueError if pr is a vector whose size is neither 1 nor n_l
    """
    n_samples, n_l, n_h = C.shape
    # Accept any real scalar or array. The previous exact type check rejected
    # ints and numpy scalars such as pr[0].
    pr = np.asarray(pr, dtype=float)
    if pr.ndim > 0:
        pr = pr.reshape(-1)
        if pr.size not in (1, n_l):
            raise ValueError(
                "pr must be a scalar or hold one value per lower capsule "
                "({}), got {} values".format(n_l, pr.size)
            )
    # Remove the expected contribution of the passive samples from the summed
    # maxima, then average over the remaining active samples only.
    amm = (C.max(axis=2).sum(axis=0) - (pr * n_samples) / n_h) / ((1 - pr) * n_samples)
    return amm

In [None]:
# Scenario: perfectly dynamic routing where 90% of the samples are passive.
n_l = 16
n_h = 8
pr = np.array([0.9]*n_l)
n_samples = 1000
C = get_c_perfect_dynamic(n_l, n_h, n_samples, pr=pr[0])
# get_c_perfect_dynamic concatenates the active samples first and the uniform
# (passive) ones after them; with pr=0.9 and n_samples=1000 that is 100 active
# rows followed by 900 passive rows, hence the split at index 100.
C0 = C
C1 = C[:100,:,:]
C2 = C[100:, :, :]

In [None]:
# Plain mean-max over all samples (C0 mixes active and passive rows).
mean_max(C0)

In [None]:
# Mean-max over the leading slice C1 only.
mean_max(C1)

In [None]:
# Mean-max over all samples, corrected for the per-capsule passive rate pr.
adjusted_mean_max(C0, pr)

## Adjusted Mean and Adjusted Std

In [None]:
def adjusted_mean_old(C, pr):
    """
        Earlier version of adjusted_mean: mean coupling over samples with the
        contribution of a rate pr of uniform (passive) samples removed.
        pr is used as given, without reshaping (the notebook calls it with a scalar).

        C (b, n_l, n_h)
        returns: (n_l, n_h)
    """
    n_samples, _, n_h = C.shape
    # Coupling mass that pr * n_samples uniform samples add to every entry.
    passive_mass = (pr * n_samples) / n_h
    active_count = (1 - pr) * n_samples
    return (C.sum(axis=0) - passive_mass) / active_count

def adjusted_mean(C, pr):
    """
        Mean coupling over samples with the contribution of passive samples
        (assumed uniform, 1 / n_h) removed.

        C (b, n_l, n_h)
        pr: passive rate in [0, 1); either a scalar (python or numpy, int or
            float) or one value per lower capsule, shape (n_l,) or (n_l, 1)
        returns: (n_l, n_h)
        raises: ValueError if pr holds neither 1 nor n_l values
    """
    n_samples, n_l, n_h = C.shape
    # Normalise every accepted form of pr to a column. The previous
    # `type(pr) == float` test sent ints and numpy scalars into the array
    # branch, where reshape(n_l, 1) failed.
    pr = np.asarray(pr, dtype=float).reshape(-1, 1)
    if pr.shape[0] not in (1, n_l):
        raise ValueError(
            "pr must be a scalar or hold one value per lower capsule "
            "({}), got {} values".format(n_l, pr.shape[0])
        )
    am = (C.sum(axis=0) - (pr * n_samples) / n_h) / ((1 - pr) * n_samples)
    return am

def adjusted_std_old(C, pr):
    """
        Earlier version of adjusted_std, built on adjusted_mean_old.
        Std of the couplings over samples with the squared deviations of a rate
        pr of uniform (passive) samples removed. A constant 1e-5 is added
        under the square root.

        C (b, n_l, n_h)
        returns: (n_l, n_h)
    """
    n_samples, _, n_h = C.shape
    center = adjusted_mean_old(C, pr)
    total_sq = np.sum((C - center)**2, axis=0)
    # Squared deviation contributed by the uniform (1 / n_h) passive samples.
    passive_sq = pr * n_samples * (1/n_h - center)**2
    variance = (total_sq - passive_sq) / ((1 - pr) * n_samples)
    return np.sqrt(variance + 1e-5)

def adjusted_std(C, pr):
    """
        Std of the couplings over samples, with the contribution of passive
        samples (assumed uniform, 1 / n_h) removed.

        C (b, n_l, n_h)
        pr: passive rate in [0, 1); either a scalar (python or numpy, int or
            float) or one value per lower capsule, shape (n_l,) or (n_l, 1)
        returns: (n_l, n_h)
        raises: ValueError if pr holds neither 1 nor n_l values,
                AssertionError if the adjusted variance is clearly negative
    """
    n_samples, n_l, n_h = C.shape
    # Normalise every accepted form of pr to an (n_l, 1) column. The previous
    # `type(pr) == float` test failed for ints and numpy scalars.
    pr = np.asarray(pr, dtype=float).reshape(-1, 1)
    if pr.shape[0] not in (1, n_l):
        raise ValueError(
            "pr must be a scalar or hold one value per lower capsule "
            "({}), got {} values".format(n_l, pr.shape[0])
        )
    pr = np.broadcast_to(pr, (n_l, 1))
    am = adjusted_mean(C, pr)
    asd = (np.sum((C - am)**2, axis=0) - (pr * n_samples * (1/n_h - am)**2)) / ((1 - pr) * n_samples)
    # Tiny negative values are rounding noise; anything larger is a real problem.
    assert asd.min() > -1e-6, "numerics {}".format(asd.min())
    asd = np.maximum(0,asd)
    asd = np.sqrt(asd)
    return asd

In [None]:
# Same scenario as for the mean-max section: perfectly dynamic routing, 90% passive.
n_l = 16
n_h = 8
pr = np.array([0.9] * n_l)
n_samples = 1000
C = get_c_perfect_dynamic(n_l, n_h, n_samples, pr=pr[0])
# get_c_perfect_dynamic concatenates the active samples first and the uniform
# (passive) ones after them; with pr=0.9 and n_samples=1000 that is 100 active
# rows followed by 900 passive rows, hence the split at index 100.
C0 = C
C1 = C[:100,:,:]
C2 = C[100:, :, :]

In [None]:
# Side-by-side mean couplings: all samples, leading (active) slice, old and new
# adjusted mean computed on all samples, and the trailing (passive) slice.
fig, axes = plt.subplots(1, 5, figsize=(24, 8))
plot_mat2(C0.mean(axis=0), ax=axes[0], title="m(C)")
plot_mat2(C1.mean(axis=0), ax=axes[1], title="m(C_active)")
# the old variant is called with the scalar rate, the new one with the per-capsule vector
plot_mat2(adjusted_mean_old(C0, pr[0]), ax=axes[2], title="old am(C)")
plot_mat2(adjusted_mean(C0, pr), ax=axes[3], title="am(C)")
plot_mat2(C2.mean(axis=0), ax=axes[4], title="m(C_passive)")

In [None]:
# Side-by-side std of the couplings: all samples, leading (active) slice, old and
# new adjusted std computed on all samples, and the trailing (passive) slice.
fig, axes = plt.subplots(1, 5, figsize=(24, 8))
plot_mat2(C0.std(axis=0), ax=axes[0], title="std(C)")
plot_mat2(C1.std(axis=0), ax=axes[1], title="std(C_active)")
# the old variant is called with the scalar rate, the new one with the per-capsule vector
plot_mat2(adjusted_std_old(C0, pr[0]), ax=axes[2], title="old_astd(C)")
plot_mat2(adjusted_std(C0, pr), ax=axes[3], title="astd(C)")
plot_mat2(C2.std(axis=0), ax=axes[4], title="std(C_passive)")

# Dynamic Coefficient

In [None]:
def dynamics2(C):
    """
        Dynamic coefficient variant normalised by mean_max(C).
        Mean std of the couplings over samples, divided by mean_max(C) times
        the std of a Bernoulli(1 / n_h) variable.

        C (b, n_l, n_h)
        returns: (n_l,)
    """
    _, _, n_h = C.shape
    std_pr = np.sqrt(1/n_h * (1 - 1/n_h))
    mean_std = C.std(axis=0).mean(axis=1)
    scale = mean_max(C) * std_pr
    return mean_std / scale

def dynamics(C):
    """
        Dynamic coefficient per lower capsule.
        Mean std of the couplings over samples, divided by the largest
        coupling of that capsule (over samples and higher capsules) times the
        std of a Bernoulli(1 / n_h) variable.

        C (b, n_l, n_h)
        returns: (n_l,)
    """
    _, _, n_h = C.shape
    std_pr = np.sqrt(1/n_h * (1 - 1/n_h))
    mean_std = C.std(axis=0).mean(axis=1)
    peak = C.max(axis=(0,2))
    return mean_std / (peak * std_pr)

def adjusted_dyn_old(C, pr):
    """
        Dynamic coefficient based on the passive-adjusted std.
        Mean adjusted_std(C, pr) per lower capsule, divided by the largest
        coupling of that capsule times the std of a Bernoulli(1 / n_h) variable.

        C (b, n_l, n_h)
        pr: passive rate, forwarded to adjusted_std
        returns: (n_l,)
    """
    _, _, n_h = C.shape
    std_pr = np.sqrt(1/n_h * (1 - 1/n_h))
    mean_adj_std = adjusted_std(C, pr).mean(axis=1)
    peak = C.max(axis=(0,2))
    return mean_adj_std / (peak * std_pr)

def adjusted_dynamics2(C, pr):
    """
        Dynamic coefficient variant normalised by adjusted_mean_max(C, pr).
        Mean adjusted_std(C, pr) per lower capsule, divided by the adjusted
        mean-max times the std of a Bernoulli(1 / n_h) variable.

        C (b, n_l, n_h)
        pr: passive rate, forwarded to adjusted_std and adjusted_mean_max
        returns: (n_l,)
    """
    _, _, n_h = C.shape
    std_pr = np.sqrt(1/n_h * (1 - 1/n_h))
    mean_adj_std = adjusted_std(C, pr).mean(axis=1)
    adj_mean_max = adjusted_mean_max(C, pr)
    return mean_adj_std / (adj_mean_max * std_pr)

def adjusted_dyn(C, pr):
    """
        Passive-adjusted dynamic coefficient per lower capsule.
        Mean adjusted_std(C, pr) per lower capsule, divided by the std of a
        Bernoulli(1 / n_h) variable times the largest coupling of that capsule
        (over samples and higher capsules).

        C (b, n_l, n_h)
        pr: passive rate, forwarded to adjusted_std
        returns: (n_l,)
    """
    _, _, n_h = C.shape
    std_pr = np.sqrt(1/n_h * (1 - 1/n_h))
    mean_adj_std = adjusted_std(C, pr).mean(axis=1)
    # Normalised by the raw maximum coupling, not by adjusted_mean_max.
    peak = C.max(axis=(0,2))
    return mean_adj_std / (std_pr * peak)

In [None]:
# Scenario: perfectly dynamic routing where 90% of the samples are passive.
n_l = 16
n_h = 8
pr = np.array([0.9] * n_l)
n_samples = 10000
C = get_c_perfect_dynamic(n_l, n_h, n_samples, pr=pr[0])
# get_c_perfect_dynamic puts the active samples first and appends the passive
# ones, so the split index is derived from pr and n_samples. A fixed index of
# 100 is only right for n_samples=1000 and would leave active rows in C2 here.
n_active = n_samples - int(pr[0] * n_samples)
C0 = C
C1 = C[:n_active, :, :]
C2 = C[n_active:, :, :]
#

In [None]:
# Passive-adjusted dynamic coefficient on all samples.
adjusted_dyn(C, pr)

In [None]:
# Unadjusted dynamic coefficient on the leading (active) slice, for comparison.
dynamics(C1)

# Test Metrics

In [None]:
def print_metrics(C, pr):
    """
        Print two summary numbers for a coupling tensor:
        mdy: adjusted_dyn(C, pr) averaged over lower capsules
        mma: adjusted_mean_max(C, pr) averaged over lower capsules
    """
    summary = (
        adjusted_dyn(C, pr).mean(),
        adjusted_mean_max(C, pr).mean(),
    )
    print("mdy: {:.3f}, mma: {:.3f}".format(*summary))

### Perfect Routing

In [None]:
# Perfect routing without passive samples.
n_l = 20
n_h = 10
n_samples = 10000

C = get_c_perfect_dynamic(n_l, n_h, n_samples, pr=0)
# Per-coupling statistics over the sample axis, each (n_l, n_h).
C_mu = C.mean(axis=0)
C_sd = C.std(axis=0)
C_mx = C.max(axis=0)
#
fig, axes = plt.subplots(1, 3, figsize=(20, 10))
plot_mat2(C_mu, ax=axes[0], vmin=0, vmax=1, title="mean")
plot_mat2(C_sd, ax=axes[1], vmin=0, vmax=0.5, title="sd")
plot_mat2(C_mx, ax=axes[2], vmin=0, vmax=1, title="max")
plt.show()

In [None]:
# Metrics for perfect routing across layer sizes and passive rates.
# n_samples is not set in this cell; it comes from an earlier cell.
for n_l in [20, 40, 60]:
    for n_h in [10, 20, 40]:
        for pr in [0, 0.1, 0.4, 0.8]:
            # rebinds the loop variable: scalar rate -> one rate per lower capsule
            pr = np.array([pr] * n_l)
            C = get_c_perfect_dynamic(n_l, n_h, n_samples, pr=pr[0])
            print_metrics(C, pr)

### Uniform Routing

In [None]:
# Uniform routing: every coupling is produced by get_c_uniform.
n_l = 20
n_h = 10
n_samples = 10000

C = get_c_uniform(n_l, n_h, n_samples)
# Per-coupling statistics over the sample axis, each (n_l, n_h).
C_mu = C.mean(axis=0)
C_sd = C.std(axis=0)
C_mx = C.max(axis=0)
#
fig, axes = plt.subplots(1, 3, figsize=(20, 10))
plot_mat2(C_mu, ax=axes[0], vmin=0, vmax=1, title="mean")
plot_mat2(C_sd, ax=axes[1], vmin=0, vmax=0.5, title="sd")
plot_mat2(C_mx, ax=axes[2], vmin=0, vmax=1, title="max")
plt.show()

In [None]:
# Metrics for uniform routing across layer sizes, evaluated with passive rate 0.
pr = 0
n_samples = 10000
for n_l in [20, 40, 60]:
    for n_h in [10, 20, 40]:
        C = get_c_uniform(n_l, n_h, n_samples)
        # print_metrics gets one rate per lower capsule
        print_metrics(C, pr=np.array([pr] * n_l))

### Random Routing

In [None]:
# Random routing with coupling strength cs=4 and no passive samples.
n_l = 20
n_h = 10
n_samples = 10000

C = get_c_rand(n_l, n_h, n_samples, cs=4)
# Per-coupling statistics over the sample axis, each (n_l, n_h).
C_mu = C.mean(axis=0)
C_sd = C.std(axis=0)
C_mx = C.max(axis=0)
#
fig, axes = plt.subplots(1, 3, figsize=(20, 10))
plot_mat2(C_mu, ax=axes[0], vmin=0, vmax=1, title="mean")
plot_mat2(C_sd, ax=axes[1], vmin=0, vmax=0.5, title="sd")
plot_mat2(C_mx, ax=axes[2], vmin=0, vmax=1, title="max")
plt.show()

In [None]:
# Metrics for random routing (cs=8, no passive samples) across layer sizes.
# n_samples is not set in this cell; it comes from an earlier cell.
for n_l in [20, 40, 60]:
    for n_h in [10, 20, 40]:
        C = get_c_rand(n_l, n_h, n_samples, pr=0, cs=8)
        print_metrics(C, np.array([0] * n_l))

In [None]:
# Metrics for random routing with cs=8 across passive rates.
# n_samples is not set in this cell; it comes from an earlier cell.
n_l = 30
n_h = 10
for pr in [0, 0.2, 0.4, 0.8]:
    # rebinds the loop variable: scalar rate -> one rate per lower capsule
    pr = np.array([pr] * n_l)
    C = get_c_rand(n_l, n_h, n_samples, pr=pr[0], cs=8)
    print_metrics(C, pr)

In [None]:
# Same sweep over passive rates as for cs=8, with weaker coupling strength cs=3.
# n_samples is not set in this cell; it comes from an earlier cell.
n_l = 30
n_h = 10
for pr in [0, 0.2, 0.4, 0.8]:
    # rebinds the loop variable: scalar rate -> one rate per lower capsule
    pr = np.array([pr] * n_l)
    C = get_c_rand(n_l, n_h, n_samples, pr=pr[0], cs=3)
    print_metrics(C, pr)

### Static Routing

In [None]:
# Static routing with coupling strength cs=4 and 10% passive samples.
n_l = 20
n_h = 10
#
n_samples = 10000

C = get_c_static(n_l, n_h, n_samples, cs=4, pr=0.1)
# Per-coupling statistics over the sample axis, each (n_l, n_h).
C_mu = C.mean(axis=0)
C_sd = C.std(axis=0)
C_mx = C.max(axis=0)

# Panels are titled like the perfect / uniform / random routing figures, so the
# figure can be read on its own.
fig, axes = plt.subplots(1, 3, figsize=(30, 10))
plot_mat2(C_mu, ax=axes[0], vmin=0, vmax=1, title="mean")
plot_mat2(C_sd, ax=axes[1], vmin=0, vmax=0.5, title="sd")
plot_mat2(C_mx, ax=axes[2], vmin=0, vmax=1, title="max")
plt.show()

In [None]:
# Metrics for static routing (cs=4) across passive rates.
# n_samples is not set in this cell; it comes from an earlier cell.
n_l = 20
n_h = 10
for pr in [0, 0.2, 0.4, 0.6]:
    # rebinds the loop variable: scalar rate -> one rate per lower capsule
    pr = np.array([pr] * n_l)
    C = get_c_static(n_l, n_h, n_samples, cs=4, pr=pr[0])
    print_metrics(C, pr)

In [None]:
# Metrics for static routing with cs=4 and no passive samples, across layer sizes.
# n_samples is not set in this cell; it comes from an earlier cell.
for n_l in [30, 40, 60]:
    for n_h in [10, 20, 30]:
        pr = np.array([0] * n_l)
        C = get_c_static(n_l, n_h, n_samples, cs=4, pr=pr[0])
        print_metrics(C, pr)

In [None]:
#(0.5 / 0.2)
# masd: divide by the max before summing, to get a relative scale

In [None]:
# Metrics for static routing with stronger couplings (cs=8), no passive samples.
# n_samples is not set in this cell; it comes from an earlier cell.
for n_l in [30, 40, 60]:
    for n_h in [10, 20, 30]:
        pr = np.array([0] * n_l)
        C = get_c_static(n_l, n_h, n_samples, cs=8, pr=pr[0])
        print_metrics(C, pr)

In [None]:
# Metrics for static routing with weaker couplings (cs=2), no passive samples.
# n_samples is not set in this cell; it comes from an earlier cell.
for n_l in [30, 40, 60]:
    for n_h in [10, 20, 30]:
        pr = np.array([0] * n_l)
        C = get_c_static(n_l, n_h, n_samples, cs=2, pr=pr[0])
        print_metrics(C, pr)

### More Dynamic Routing

In [None]:
# Routing that switches between nc=2 static patterns (cs=10), no passive samples.
# n_samples is not set in this cell; it comes from an earlier cell.
n_l = 20
n_h = 10
pr = np.array([0] * n_l)
nc = 2
cs = 10
#
C = get_c_rate_strength(n_l, n_h, n_samples, cs=cs, nc=nc, pr=pr[0])
# Per-coupling statistics over the sample axis, each (n_l, n_h).
C_mu = C.mean(axis=0)
C_sd = C.std(axis=0)
C_mx = C.max(axis=0)

# Panels are titled like the perfect / uniform / random routing figures, so the
# figure can be read on its own.
fig, axes = plt.subplots(1, 3, figsize=(30, 10))
plot_mat2(C_mu, ax=axes[0], vmin=0, vmax=1, title="mean")
plot_mat2(C_sd, ax=axes[1], vmin=0, vmax=0.5, title="sd")
plot_mat2(C_mx, ax=axes[2], vmin=0, vmax=1, title="max")
plt.show()

In [None]:
# Metrics for an increasing number of static patterns nc (cs=4, no passive samples).
# n_samples is not set in this cell; it comes from an earlier cell.
n_l = 20
n_h = 10
pr = np.array([0] * n_l)
nc = 2  # overwritten by the loop variable below
cs = 4

for nc in [2, 4, 6, 8]:
    C = get_c_rate_strength(n_l, n_h, n_samples, cs=cs, nc=nc, pr=pr[0])
    print_metrics(C, pr)

In [None]:
# Same sweep over nc (cs=4), with half of the samples passive.
# n_samples is not set in this cell; it comes from an earlier cell.
n_l = 20
n_h = 10
pr = np.array([0.5] * n_l)
nc = 2  # overwritten by the loop variable below
cs = 4

for nc in [2, 4, 6, 8]:
    C = get_c_rate_strength(n_l, n_h, n_samples, cs=cs, nc=nc, pr=pr[0])
    print_metrics(C, pr)

In [None]:
# Metrics for increasing coupling strength cs with nc=3 patterns, no passive samples.
# n_samples is not set in this cell; it comes from an earlier cell.
n_l = 20
n_h = 10
pr = np.array([0] * n_l)
nc = 3

for cs in [2, 4, 6, 8, 10]:
    C = get_c_rate_strength(n_l, n_h, n_samples, cs=cs, nc=nc, pr=pr[0])
    print_metrics(C, pr)

In [None]:
# Same sweep over cs (nc=3), with half of the samples passive.
# n_samples is not set in this cell; it comes from an earlier cell.
n_l = 20
n_h = 10
pr = np.array([0.5] * n_l)
nc = 3

for cs in [2, 4, 6, 8, 10]:
    C = get_c_rate_strength(n_l, n_h, n_samples, cs=cs, nc=nc, pr=pr[0])
    print_metrics(C, pr)