In [5]:
from collections import defaultdict
from datetime import datetime
import math
import os
import sys

from cvxopt import matrix, solvers
solvers.options['show_progress'] = False
import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
import torch.nn as nn

np.set_printoptions(threshold=sys.maxsize)

In [6]:
class TwoLayerNet(nn.Module):
    """Two-layer ReLU network: Linear(d0, d1) -> ReLU -> Linear(d1, d2, no bias).

    The output layer's weight is always excluded from gradient updates.  Passing
    ``freeze=True`` additionally freezes the first layer's weight and bias, so
    the whole network becomes a fixed function.
    """

    def __init__(self, d0, d1, d2, freeze=False):
        """Build and randomly initialise the network.

        Args:
            d0: Input dimension.
            d1: Hidden-layer width.
            d2: Output dimension.
            freeze: If True, the first layer's parameters do not require grad.
        """
        super().__init__()

        # The order of the random draws below (layer construction, bias, weight,
        # then the output layer) determines the parameters obtained for a given
        # torch seed, so it must not be rearranged.
        hidden = nn.Linear(d0, d1)
        nn.init.normal_(hidden.bias, mean=0., std=np.sqrt(2. / d0))
        nn.init.kaiming_normal_(hidden.weight, nonlinearity='relu')
        if freeze:
            for param in (hidden.bias, hidden.weight):
                param.requires_grad_(False)

        readout = nn.Linear(d1, d2, bias=False)
        nn.init.normal_(readout.weight, mean=0., std=np.sqrt(1. / d1))
        # The output weights are never trained.
        readout.weight.requires_grad_(False)

        # Index 0: hidden linear layer, 1: ReLU, 2: output linear layer.
        self.layers = nn.ModuleList([hidden, nn.ReLU(), readout])

    def forward(self, x):
        """Apply the stored layers to ``x`` one after another."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
    
def get_gaussian_data(d0, data_size, target_fn):
    """Draw standard-normal inputs and label them with ``target_fn``.

    Args:
        d0: Number of features per sample.
        data_size: Number of samples to draw.
        target_fn: Callable applied to the full ``(data_size, d0)`` float tensor.

    Returns:
        Tuple ``(x, y)`` with ``x`` the sampled inputs and ``y = target_fn(x)``.
    """
    samples = np.random.normal(size=(data_size, d0))
    inputs = torch.tensor(samples, dtype=torch.float)
    return inputs, target_fn(inputs)

def get_A(model, x):
    """Return the boolean activation pattern of the first layer on ``x``.

    Entry ``[i, j]`` is True when the pre-activation ``model.layers[0](x)[i, j]``
    is strictly positive.
    """
    first_layer = model.layers[0]
    pre_activations = first_layer(x).detach().numpy()
    return pre_activations > 0

def contains_min(model, x, y):
    """Minimise the squared error inside the model's current activation region.

    With the output weights and the ReLU activation pattern on ``x`` both held
    fixed, the network output is linear in the first-layer biases and weights.
    This function solves the resulting least-squares problem with
    ``cvxopt.solvers.qp``, subject to linear inequality constraints that keep
    every (sample, neuron) pre-activation on its current side of zero.  Neurons
    that are inactive on every sample ("dead") contribute nothing to the output
    and are left out of the problem; their parameters are kept unchanged.

    Only scalar inputs are supported: ``x`` must have shape ``(N, 1)``.  ``y``
    is used as an ``(N, 1)`` column (it is subtracted from the ``(N, 1)``
    prediction), and only the first row of the output-layer weight is used.

    Args:
        model: Network exposing ``layers[0]`` (linear layer with weight and
            bias) and ``layers[2]`` (linear layer with weight).
        x: Input tensor of shape ``(N, 1)``.
        y: Target tensor, used as shape ``(N, 1)``.

    Returns:
        Tuple ``(loss, zero_loss, same_pattern, region_dim, new_pattern)``:
            loss: Mean squared error of the QP solution.
            zero_loss: ``np.isclose(loss, 0)``.
            same_pattern: True when the activation pattern recomputed from the
                solved parameters equals the original pattern.
            region_dim: ``2 * d1`` minus the number of samples whose smallest
                absolute pre-activation is numerically zero.
            new_pattern: Boolean ``(N, d1)`` activation pattern of the solution.

    Raises:
        ValueError: If ``x`` is not of shape ``(N, 1)``.
    """
    N = x.size()[0]
    first_layer = model.layers[0]
    d1 = len(first_layer.weight)

    out1 = first_layer(x).detach().numpy()
    pattern = out1 > 0

    w = first_layer.weight.detach().numpy().astype(np.float64)
    b = first_layer.bias.detach().numpy().astype(np.float64)
    v = model.layers[2].weight.detach().numpy().astype(np.float64)
    x_np = x.detach().numpy().astype(np.float64)
    y_np = y.detach().numpy().astype(np.float64)

    # The design matrix and the constraints below rely on broadcasting an
    # (N, 1) input column against per-neuron rows; any other input width would
    # either fail or silently broadcast to something meaningless.
    if x_np.ndim != 2 or x_np.shape[1] != 1:
        raise ValueError(
            f'contains_min supports only inputs of shape (N, 1), got {x_np.shape}')

    # A neuron is dead when it is inactive on every sample.
    alive = pattern.any(axis=0)
    k = int(alive.sum())

    # Design matrix of the linear regression problem: for every alive neuron one
    # column multiplying its bias (v_j) and one multiplying its weight (v_j * x),
    # both zeroed on samples where the neuron is inactive.
    alive_v = v[:1, alive]
    alive_pattern = pattern[:, alive]
    masked_v = np.where(alive_pattern, alive_v, 0.)
    masked_vx = np.where(alive_pattern, alive_v * x_np, 0.)
    x_tilde = np.concatenate((masked_v, masked_vx), axis=1)

    # Quadratic program: minimise 1/2 * beta^T P beta + q^T beta  s.t.  G beta <= h,
    # where beta = (biases of alive neurons, weights of alive neurons).
    P = matrix(x_tilde.T @ x_tilde)
    q = matrix(- x_tilde.T @ y_np)

    # One constraint row per (sample, alive neuron), at row index xi * k + pi:
    #   active   (w x + b > 0):   -b - w x <= 0
    #   inactive (w x + b <= 0):   b + w x <= 0
    # NOTE: G is dense with N * k rows and 2 * k columns, so memory grows
    # quickly with the data size and the width.
    sign = np.where(alive_pattern, -1., 1.)
    G = np.zeros((N * k, 2 * k))
    rows = np.arange(N * k)
    cols = np.tile(np.arange(k), N)
    G[rows, cols] = sign.reshape(-1)
    G[rows, k + cols] = (sign * x_np).reshape(-1)
    G = matrix(G)
    h = matrix(np.zeros(N * k))

    beta_hat = np.array(solvers.qp(P, q, G, h)['x'])

    # Check loss
    pred_y = x_tilde @ beta_hat
    loss = np.mean((pred_y - y_np)**2)
    zero_loss = np.isclose(loss, 0)

    # Check activation pattern of the found solution: write the solved biases
    # and weights back into copies of the original parameters (dead neurons keep
    # their original values) and recompute the pre-activations.
    new_weight = w.copy()
    new_bias = b.copy()
    new_bias[alive] = beta_hat[:k, 0]
    new_weight[alive, 0] = beta_hat[k:, 0]

    new_out = x_np @ new_weight.T + new_bias
    new_pattern = new_out > 0

    same_pattern = np.array_equal(pattern, new_pattern)

    # Every sample whose closest pre-activation is numerically zero is counted
    # as one equality constraint and lowers the region dimension by one.
    parameter_dim = d1 * 2
    eq_num = np.sum(np.isclose(np.min(np.abs(out1), axis=-1), 0))
    region_dim = parameter_dim - eq_num

    return loss, zero_loss, same_pattern, region_dim, new_pattern

In [None]:
# Number of student networks (with pairwise distinct activation patterns)
# evaluated for every (width, data size) combination.
RUNS_NUM = 100

# Hidden-layer widths and dataset sizes to sweep over.
d1_arr = [100 * (i+ 1) for i in range(5)]#20)]
data_size_arr = [100 * (i+ 1) for i in range(5)]#20)]

# One row per width, one column per data size.
total_zero_loss = []
total_same_pattern = []
total_region_dim = []

# NOTE(review): no random seed is set in the visible cells, so the numbers
# below change from one execution to the next.
for d1 in d1_arr:
    print(f'!!! d1: {d1}')
    d1_zero_loss = []
    d1_same_pattern = []
    d1_region_dim = []
    for data_size in data_size_arr:
        print(f'!!! data_size: {data_size}')
        # The targets come from a fully frozen teacher of the same width.
        teacher_net = TwoLayerNet(d0=1, d1=d1, d2=1, freeze=True)
        teacher_net.train(False)
        x, y = get_gaussian_data(d0=1, data_size=data_size, target_fn=teacher_net)

        original_pattern_arr = []
        seen_patterns = set()  # same hashes as original_pattern_arr, O(1) lookup
        same_pattern_arr = np.zeros(RUNS_NUM, dtype=bool)
        region_dim_arr = np.zeros(RUNS_NUM)
        zero_loss_arr = np.zeros(RUNS_NUM, dtype=bool)
        lr_pattern_arr = []

        run_id = 0
        while run_id < RUNS_NUM:
            student_net = TwoLayerNet(d0=1, d1=d1, d2=1)
            pattern_hash = hash(tuple(get_A(student_net, x).reshape(-1)))
            # Resample students whose activation pattern was already evaluated.
            if pattern_hash in seen_patterns:
                continue
            seen_patterns.add(pattern_hash)
            original_pattern_arr.append(pattern_hash)

            (_, zero_loss_arr[run_id], same_pattern_arr[run_id],
             region_dim_arr[run_id], lr_pattern) = contains_min(student_net, x, y)
            lr_pattern_arr.append(hash(tuple(lr_pattern.reshape(-1))))
            run_id += 1

            # Report progress once per completed block of 100 runs; rejected
            # duplicates no longer repeat the message.
            if run_id % 100 == 0:
                print(f'=== Run {run_id}/{RUNS_NUM} ===')

        print(f'Number of global minima: {np.sum(zero_loss_arr)}/{RUNS_NUM}')
        print(f'Number of same patterns: {np.sum(same_pattern_arr)}/{RUNS_NUM}')
        print(f'Unique lr patterns: {np.unique(lr_pattern_arr).shape[0]}/{RUNS_NUM}')
        print(f'Average region dimension: {np.mean(region_dim_arr)}')
        print()

        # Percentages over the runs, and the mean region dimension.
        d1_zero_loss.append(np.sum(zero_loss_arr) / RUNS_NUM * 100)
        d1_same_pattern.append(np.sum(same_pattern_arr) / RUNS_NUM * 100)
        d1_region_dim.append(np.mean(region_dim_arr))

    total_zero_loss.append(d1_zero_loss)
    total_same_pattern.append(d1_same_pattern)
    total_region_dim.append(d1_region_dim)

total_zero_loss = np.asarray(total_zero_loss)
print(f'total_zero_loss:\n{total_zero_loss.shape}')
print(f'total_zero_loss:\n{total_zero_loss}')
print(f'total_same_pattern:\n{total_same_pattern}')
print(f'total_region_dim:\n{total_region_dim}')

!!! d1: 100
!!! data_size: 100
=== Run 100/100 ===
Number of global minima: 0/100
Number of same patterns: 100/100
Unique lr patterns: 100/100
Average region dimension: 200.0

!!! data_size: 200
=== Run 100/100 ===
Number of global minima: 0/100
Number of same patterns: 100/100
Unique lr patterns: 100/100
Average region dimension: 200.0

!!! data_size: 300


In [None]:
# Spacing (in grid cells) between labelled ticks on both heatmap axes.
STEP = 5

colormap_arr = ['RdYlBu_r']

# Create the output directory up front; savefig does not create missing folders.
output_dir = 'images/random_local_or_global'
os.makedirs(output_dir, exist_ok=True)

# Label every STEP-th grid cell plus the last one. Going through a set avoids a
# duplicated final tick when the last index is itself a multiple of STEP.
x_tick_pos = sorted(set(range(0, len(data_size_arr), STEP)) | {len(data_size_arr) - 1})
y_tick_pos = sorted(set(range(0, len(d1_arr), STEP)) | {len(d1_arr) - 1})

for colormap in colormap_arr:
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    fig = plt.figure(figsize=(10.7, 8), dpi=100)
    ax = fig.add_subplot(111)
    ax.tick_params(axis='both', which='major', labelsize=36)
    ax.tick_params(axis='both', which='minor', labelsize=36)

    ax.set_xlabel('Data size', size=40)
    ax.set_ylabel('Network width', size=40)
    ax.margins(x=0)

    ax.set_xticks(x_tick_pos)
    ax.set_xticklabels([data_size_arr[pos] for pos in x_tick_pos])
    ax.set_yticks(y_tick_pos)
    ax.set_yticklabels([d1_arr[pos] for pos in y_tick_pos])

    # Rows of total_zero_loss correspond to widths (y axis), columns to data
    # sizes (x axis); origin='lower' puts the smallest width at the bottom.
    cp = ax.imshow(total_zero_loss, cmap=colormap, origin='lower', interpolation='nearest')
    cbar = fig.colorbar(cp)

    cbar.ax.tick_params(labelsize=36)
    fig.tight_layout()

    fig.savefig(f'{output_dir}/{timestamp}_global_percentage_{colormap}.png')