In [322]:
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from scipy import linalg
from scipy.optimize import linprog
from numpy.linalg import matrix_rank as rank
import torch

In [323]:
# Models
import sys
sys.path.append("..")
from models import MLP
from metrics import jacobian
from data_generation import RandomPolynomialMapping

In [324]:
def check_global(A, v, X, y, d0=None, d1=None, n=None, method=None):
    """
    Solve a feasibility linear program to check whether an exact fit of the data exists
    within a given activation region.

    The LP models the network as f(x) = sum_j v_j * relu(w_j . x) with the output weights v
    held fixed and no biases. Inside the region described by A, relu(w_j . x_i) equals
    A[j, i] * (w_j . x_i), so an exact fit is linear in the hidden weights W = (w_1, ..., w_d1):

        sum_j v_j * A[j, i] * (w_j . x_i) = y_i          for every sample i     (equalities)
        w_j . x_i >= 0 if A[j, i] else w_j . x_i <= 0    for every pair (i, j)  (inequalities)

    The inequalities are non-strict, so solutions on the boundary of the region are accepted.

    Inputs:
    - A is a (d1 x n) binary array; A[j, i] is truthy when hidden unit j is active on sample i
    - v is a (d1,) or (d1 x 1) array holding the output weights of the network
    - X is a (d0 x n) array holding the data features (one sample per column)
    - y is an (n,) or (n x 1) array holding the data targets
    - d0, d1, n: input dimension, hidden width and number of samples. Optional; when omitted
      they are inferred as X.shape[0], A.shape[0] and X.shape[1].
    - method: optional ``scipy.optimize.linprog`` method name (e.g. "highs"). When None,
      SciPy's default method is used.

    Returns:
    - ``results.success`` from ``linprog``: True when the LP is feasible.
    """
    A = np.asarray(A)
    X = np.asarray(X, dtype=float)
    v = np.asarray(v, dtype=float).reshape(-1)
    b_eq = np.asarray(y, dtype=float).reshape(-1)
    if d0 is None:
        d0 = X.shape[0]
    if n is None:
        n = X.shape[1]
    if d1 is None:
        d1 = A.shape[0]

    # B[j, i] = -1 for active units and +1 for inactive ones, so every sign constraint
    # can be written uniformly as B[j, i] * (w_j . x_i) <= 0.
    B = -2 * A.astype(int) + 1
    Eq = np.zeros((n, d0 * d1))
    Iq = np.zeros((n * d1, d0 * d1))
    # Unknowns are W flattened row by row: w_j occupies columns j*d0 .. (j+1)*d0 - 1.
    for i in range(n):
        x = X[:, i]
        for j in range(d1):
            start = j * d0
            end = (j + 1) * d0
            Iq[i * d1 + j, start:end] = B[j, i] * x
            Eq[i, start:end] = v[j] * A[j, i] * x

    options = {} if method is None else {"method": method}
    # Hidden weights may take any sign. linprog's default bounds are (0, None), which
    # would wrongly restrict the search to non-negative weights.
    results = linprog(
        np.zeros(d0 * d1),
        A_ub=Iq,
        b_ub=np.zeros(n * d1),
        A_eq=Eq,
        b_eq=b_eq,
        bounds=(None, None),
        **options,
    )
    return results.success

In [333]:
def generate_gauss_data(d0, n):
    """
    Sample a synthetic regression dataset with i.i.d. standard-normal entries.

    Draws from NumPy's global RNG (features first, then targets) and returns
    (X, y) as float32 torch tensors with X of shape (n, d0) and y of shape (n, 1).
    """
    features = np.random.randn(n, d0)
    targets = np.random.randn(n, 1)
    return (torch.from_numpy(features).float(),
            torch.from_numpy(targets).float())

In [326]:
def sample_activation_region(d0, d1, n, data_func, d2=1, L=1):
    """
    Draw one random network and dataset, then report statistics of the activation
    region of the first layer on that data.

    Parameters:
    - d0, d1, d2, L: forwarded to ``MLP(d0, d1, d2, L)``.
    - n: number of samples requested from ``data_func``.
    - data_func: callable ``data_func(d0, n) -> (X, y)`` returning torch tensors;
      X is presumably (n, d0), as produced by ``generate_gauss_data`` — TODO confirm for other sources.

    Returns:
    - r: rank of the Jacobian returned by ``jacobian(model, X)``.
    - has_global_min: result of ``check_global`` for the observed activation pattern.
    """
    model = MLP(d0, d1, d2, L)
    X, y = data_func(d0, n)
    J = jacobian(model, X).detach().numpy()
    r = rank(J)
    W = model.layers[0].weight.detach()
    # Activation pattern: A[j, i] is True when hidden unit j has positive pre-activation on sample i.
    # NOTE(review): only the weight matrix is used; if layers[0] has a bias it is ignored — confirm.
    A = (W @ X.T).detach().numpy() > 0
    # check_global expects one sample per column, i.e. a (d0, n) array.
    X = X.detach().numpy().T
    v = model.last_layer.detach().numpy()
    has_global_min = check_global(A, v, X, y, d0, d1, n)
    return r, has_global_min

In [354]:
# Sweep configuration: data dimension, numbers of data points, hidden-layer widths
# and number of random trials per (width, n) cell.
d0 = 1
N = np.arange(2, 21, 1)     # n  = 2, 3, ..., 20
D1 = np.arange(5, 100, 5)   # d1 = 5, 10, ..., 95
T = 10
# Seed the RNGs the sweep draws from so the heatmaps are reproducible under Restart & Run All:
# numpy (generate_gauss_data) and torch (presumably used by MLP's weight initialisation).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
poly = RandomPolynomialMapping(-1,1,2)
# Alternative data source (not used by the sweep below):
# X,y= poly.generate_random_data(d0, n)

In [None]:
# For every (width d1, sample count n) cell, aggregate over T random draws:
# - av_rank_J:  mean of rank(J) / n, as a percentage
# - av_globals: percentage of draws whose activation region admits an exact fit
# NOTE(review): av_rank_J is an average normalised rank, not the fraction of full-rank
# draws that the left plot's title describes — confirm which statistic is intended.
av_rank_J = np.zeros((len(D1), len(N)))
av_globals = np.zeros((len(D1), len(N)))
for i, d1_val in enumerate(D1):
    for j, n_val in enumerate(N):
        print("Computing average stats for d1="+str(d1_val) + ", n=" + str(n_val))
        for t in range(T):
            r, g = sample_activation_region(d0, d1_val, n_val, generate_gauss_data)
            av_rank_J[i,j] += r
            av_globals[i,j] += g
        av_rank_J[i,j] = 100*av_rank_J[i,j]/(T*n_val)
        # Scale to percent, matching av_rank_J and the "%" in the plot title.
        av_globals[i,j] = 100*av_globals[i,j]/T


Computing average stats for d1=5, n=2
Computing average stats for d1=5, n=3
Computing average stats for d1=5, n=4
Computing average stats for d1=5, n=5
Computing average stats for d1=5, n=6
Computing average stats for d1=5, n=7


  results = linprog(np.zeros(d0*d1), A_ub=Iq, b_ub=np.zeros(n*d1), A_eq=Eq, b_eq=y)
  results = linprog(np.zeros(d0*d1), A_ub=Iq, b_ub=np.zeros(n*d1), A_eq=Eq, b_eq=y)
  results = linprog(np.zeros(d0*d1), A_ub=Iq, b_ub=np.zeros(n*d1), A_eq=Eq, b_eq=y)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  results = linprog(np.zeros(d0*d1), A_ub=Iq, b_ub=np.zeros(n*d1), A_eq=Eq, b_eq=y)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

Computing average stats for d1=5, n=8
Computing average stats for d1=5, n=9
Computing average stats for d1=5, n=10
Computing average stats for d1=5, n=11
Computing average stats for d1=5, n=12
Computing average stats for d1=5, n=13
Computing average stats for d1=5, n=14
Computing average stats for d1=5, n=15
Computing average stats for d1=5, n=16
Computing average stats for d1=5, n=17
Computing average stats for d1=5, n=18
Computing average stats for d1=5, n=19
Computing average stats for d1=5, n=20
Computing average stats for d1=10, n=2


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

Computing average stats for d1=10, n=3
Computing average stats for d1=10, n=4
Computing average stats for d1=10, n=5
Computing average stats for d1=10, n=6
Computing average stats for d1=10, n=7
Computing average stats for d1=10, n=8
Computing average stats for d1=10, n=9
Computing average stats for d1=10, n=10
Computing average stats for d1=10, n=11
Computing average stats for d1=10, n=12
Computing average stats for d1=10, n=13
Computing average stats for d1=10, n=14
Computing average stats for d1=10, n=15
Computing average stats for d1=10, n=16
Computing average stats for d1=10, n=17
Computing average stats for d1=10, n=18
Computing average stats for d1=10, n=19
Computing average stats for d1=10, n=20
Computing average stats for d1=15, n=2
Computing average stats for d1=15, n=3
Computing average stats for d1=15, n=4
Computing average stats for d1=15, n=5
Computing average stats for d1=15, n=6
Computing average stats for d1=15, n=7


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

Computing average stats for d1=15, n=8
Computing average stats for d1=15, n=9
Computing average stats for d1=15, n=10
Computing average stats for d1=15, n=11
Computing average stats for d1=15, n=12
Computing average stats for d1=15, n=13
Computing average stats for d1=15, n=14
Computing average stats for d1=15, n=15
Computing average stats for d1=15, n=16
Computing average stats for d1=15, n=17
Computing average stats for d1=15, n=18
Computing average stats for d1=15, n=19
Computing average stats for d1=15, n=20
Computing average stats for d1=20, n=2


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

Computing average stats for d1=20, n=3


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

Computing average stats for d1=20, n=4
Computing average stats for d1=20, n=5
Computing average stats for d1=20, n=6
Computing average stats for d1=20, n=7
Computing average stats for d1=20, n=8


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

Computing average stats for d1=20, n=9
Computing average stats for d1=20, n=10
Computing average stats for d1=20, n=11
Computing average stats for d1=20, n=12
Computing average stats for d1=20, n=13
Computing average stats for d1=20, n=14
Computing average stats for d1=20, n=15
Computing average stats for d1=20, n=16
Computing average stats for d1=20, n=17
Computing average stats for d1=20, n=18
Computing average stats for d1=20, n=19
Computing average stats for d1=20, n=20
Computing average stats for d1=25, n=2
Computing average stats for d1=25, n=3


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

Computing average stats for d1=25, n=4
Computing average stats for d1=25, n=5
Computing average stats for d1=25, n=6
Computing average stats for d1=25, n=7
Computing average stats for d1=25, n=8


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

Computing average stats for d1=25, n=9
Computing average stats for d1=25, n=10
Computing average stats for d1=25, n=11
Computing average stats for d1=25, n=12
Computing average stats for d1=25, n=13
Computing average stats for d1=25, n=14
Computing average stats for d1=25, n=15
Computing average stats for d1=25, n=16
Computing average stats for d1=25, n=17
Computing average stats for d1=25, n=18
Computing average stats for d1=25, n=19
Computing average stats for d1=25, n=20
Computing average stats for d1=30, n=2


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

Computing average stats for d1=30, n=3
Computing average stats for d1=30, n=4
Computing average stats for d1=30, n=5
Computing average stats for d1=30, n=6
Computing average stats for d1=30, n=7
Computing average stats for d1=30, n=8
Computing average stats for d1=30, n=9
Computing average stats for d1=30, n=10
Computing average stats for d1=30, n=11
Computing average stats for d1=30, n=12
Computing average stats for d1=30, n=13
Computing average stats for d1=30, n=14
Computing average stats for d1=30, n=15
Computing average stats for d1=30, n=16
Computing average stats for d1=30, n=17
Computing average stats for d1=30, n=18
Computing average stats for d1=30, n=19
Computing average stats for d1=30, n=20
Computing average stats for d1=35, n=2
Computing average stats for d1=35, n=3


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

Computing average stats for d1=35, n=4
Computing average stats for d1=35, n=5
Computing average stats for d1=35, n=6
Computing average stats for d1=35, n=7
Computing average stats for d1=35, n=8


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

Computing average stats for d1=35, n=9
Computing average stats for d1=35, n=10
Computing average stats for d1=35, n=11
Computing average stats for d1=35, n=12
Computing average stats for d1=35, n=13
Computing average stats for d1=35, n=14
Computing average stats for d1=35, n=15
Computing average stats for d1=35, n=16
Computing average stats for d1=35, n=17
Computing average stats for d1=35, n=18
Computing average stats for d1=35, n=19
Computing average stats for d1=35, n=20
Computing average stats for d1=40, n=2


  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

Computing average stats for d1=40, n=3


In [None]:
# Side-by-side heatmaps over (number of data points, network width); y axis grows upward.
fig, ax = plt.subplots(1, 2, figsize=(12, 7))
panels = [
    (av_rank_J, "% of non-empty activation regions which are full (column) rank"),
    (av_globals, "% of non-empty activation regions have a global minimum"),
]
for axis, (values, title) in zip(ax, panels):
    sns.heatmap(values, ax=axis, xticklabels=N, yticklabels=D1)
    axis.invert_yaxis()
    axis.set_xlabel("# Data Points")
    axis.set_ylabel("Network Width")
    axis.set_title(title)
plt.tight_layout()

In [229]:
# Sizes for a single experiment: MLP(d0, d1, d2, L) arguments and number of samples n.
d0, d1, d2, L, n = 1, 100, 1, 1, 20

In [230]:
# Draw one dataset (X, y) of n samples in d0 dimensions from a RandomPolynomialMapping.
# NOTE(review): the meaning of the constructor arguments (-1, 1, 2) is not visible here —
# presumably an input range and a polynomial degree; confirm in data_generation.
poly = RandomPolynomialMapping(-1,1,2)
X,y= poly.generate_random_data(d0, n)

In [231]:
# Build one random network and extract the quantities check_global needs for this dataset.
model = MLP(d0, d1, d2, L)
J = jacobian(model, X).detach().numpy()
W = model.layers[0].weight.detach()
# First-layer pre-activations and the resulting activation pattern (True where positive).
# X is presumably an (n, d0) tensor here, so P is (d1, n) — TODO confirm.
P =(W@X.T).detach().numpy()
A = P>0
W = W.numpy()
# NOTE: X is rebound from a tensor to a transposed numpy array (one sample per column).
# Re-running this cell without re-running the data cell fails: numpy arrays have no .detach().
X = X.detach().numpy().T
v = model.last_layer.detach().numpy()

In [240]:
# Rank of the Jacobian of the current model on the current data.
np.linalg.matrix_rank(J)

20

In [232]:
# Feasibility of an exact fit inside the current activation region (check_global takes the sizes explicitly).
check_global(A, v, X, y, d0, d1, n)

True


0

In [None]:
# Inline version of the check_global LP for the current A, v, X, y (kept for debugging).
B = -2*A+1
Eq = np.zeros((n, d0*d1))
Iq = np.zeros((n*d1, d0*d1))
# Form equality and inequality matrices for linear program.
# Row i*d1 + j of Iq is the sign constraint of hidden unit j on sample i.
for i in range(n):
    x = X[:,i]
    for j in range(d1):
        start = j*d0
        end = (j+1)*d0
        Iq[(i*d1)+j,start:end]=B[j,i]*x
        Eq[i,start:end]=v[j]*A[j,i]*x
# bounds=(None, None): hidden weights may be negative (linprog defaults to non-negative variables).
results = linprog(np.zeros(d0*d1), A_ub=Iq, b_ub=np.zeros(n*d1), A_eq=Eq, b_eq=y, bounds=(None, None))
print(results.success)

In [136]:
# Inspect X's shape; X is a torch tensor or a transposed numpy array depending on which cells ran last.
X.shape

(2, 20)

In [87]:
# Inspect the first-layer weights as a flat vector.
W.flatten()

array([ 0.32267588,  0.326783  , -0.09220368, -0.23511627,  0.14187759,
        0.0670296 ,  0.630227  ,  0.07305872, -0.6311387 , -0.58621967],
      dtype=float32)

In [109]:
# Re-inspect X's shape; the displayed type depends on whether the model cell has transposed X yet.
X.shape

torch.Size([20, 2])

In [79]:
# Activation-pattern shape; check_global expects (d1, n).
A.shape

(64, 10)

In [80]:
# Output-weight shape; check_global documents it as (d1, 1).
v.shape

(64, 1)

In [110]:
# NOTE: relies on `i` leaked from an earlier loop cell, so this is not reproducible on a fresh kernel.
X[:,i].shape

torch.Size([20])

In [218]:
# Build the LP matrices used by the linprog cell below (same layout as check_global).
# B[j, i] = -1 for active units and +1 for inactive ones, so each sign constraint is B[j, i] * (w_j . x_i) <= 0.
B = -2*A+1
Eq = np.zeros((n, d0*d1))
Iq = np.zeros((n*d1, d0*d1))
# Form equality and inequality matrices for linear program.
# Unknowns are W flattened row by row: w_j occupies columns j*d0 .. (j+1)*d0 - 1.
for i in range(n):
    x = X[:,i]
    for j in range(d1):
        start = j*d0
        end = (j+1)*d0
        # Row i*d1 + j: sign constraint of hidden unit j on sample i.
        Iq[(i*d1)+j,start:end]=B[j,i]*x
        # Row i: network output on sample i must equal y_i.
        Eq[i,start:end]=v[j]*A[j,i]*x


In [219]:
# Feasibility LP (zero objective). bounds=(None, None) because hidden weights may be negative;
# linprog's default bounds would restrict every weight to be >= 0.
results = linprog(np.zeros(d0*d1), A_ub=Iq, b_ub=np.zeros(n*d1), A_eq=Eq, b_eq=y, bounds=(None, None))

In [162]:
# Whether linprog found a feasible point.
results.success

False

In [155]:
# First inequality row: sign constraint of hidden unit 0 on sample 0.
Iq[0]

array([-1.18474662,  1.41357958,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ])

In [156]:
# Second inequality row: sign constraint of hidden unit 1 on sample 0, given the (i*d1)+j row layout.
Iq[1]

array([ 0.        ,  0.        ,  1.18474662, -1.41357958,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ])

In [157]:
# Full inequality matrix (large; prefer slicing, e.g. Iq[:5], for anything beyond tiny sizes).
Iq

array([[-1.18474662,  1.41357958,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  1.18474662, -1.41357958,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-1.78071713,  0.27677271,  0.        ,  0.        , -1.18474662,
         1.41357958,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  1.78071713, -0.27677271,  0.        ,
         0.        , -1.18474662,  1.41357958,  0.        ,  0.        ],
       [-1.66319883,  0.23704459,  0.        ,  0.        ,  1.78071713,
        -0.27677271,  0.        ,  0.        ,  1.18474662, -1.41357958],
       [ 0.        ,  0.        ,  1.66319883, -0.23704459,  0.        ,
         0.        , -1.78071713,  0.27677271,  0.        ,  0.        ],
       [-0.99680674, -0.10799686,  0.        ,  0.        ,  1.66319883,
        -0.23704459,  0.        ,  0.        