# Smooth Monotonic Networks: Counting silent neurons

## General definitions

In [1]:
# Load IPython's autoreload extension so that edited modules can be
# re-imported without restarting the kernel.
%load_ext autoreload
# NOTE(review): without a mode argument, %autoreload triggers a one-off reload
# at this point; if continuous reloading before every cell execution is
# intended, `%autoreload 2` is the usual form -- confirm.
%autoreload 

In [2]:
import numpy as np
import random

import torch 
import torch.nn as nn

from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics import r2_score as r2
from sklearn.isotonic import IsotonicRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures

import matplotlib.pyplot as plt
from tqdm.notebook import tnrange

from MonotonicNN import SmoothMonotonicNN, MonotonicNN, MonotonicNNAlt
from MonotonicNNPaperUtils import Progress, total_params, fit_torch

from monotonenorm import GroupSort, direct_norm, SigmaNet

## Univariate experiments 
Section 4.1 in the manuscript.

In [3]:
# --- Experiment protocol ---
T = 21          # trials per task; odd, so that a "median trial" exists
N_train = 100   # size of the training data set
N_test = 1000   # size of the test data set
sigma = 0.01    # noise level of the training targets; may be varied

# --- Model sizes ---
K = 6                   # number of SMM groups; H_k = K is used throughout
ls, ls_small = 75, 35   # lattice points (k in original paper) and a smaller alternative
width_small = K
width = K + 2

In [4]:
def generate1D(function_name, sigma=0., random=False, xrange=1., N=50):
    """Sample a univariate target function on [0, xrange).

    Args:
        function_name: One of 'sigmoid10' (logistic curve with slope 10,
            centred at xrange / 2), 'sq' (x**2) or 'sqrt' (sqrt(x)).
        sigma: Scale of the standard-normal noise added to the targets.
        random: If True, the inputs are drawn uniformly from [0, xrange) and
            sorted in ascending order; otherwise they form an evenly spaced
            grid 0, xrange/N, 2*xrange/N, ...
        xrange: Upper end (exclusive) of the input interval.
        N: Number of samples.

    Returns:
        Tuple (x, y) with x of shape (N, 1) and y of shape (N,).

    Raises:
        ValueError: If function_name is not one of the supported names.
    """
    supported = ('sigmoid10', 'sq', 'sqrt')
    if function_name not in supported:
        # Fail early with a clear message instead of an unbound-variable
        # error on y further down.
        raise ValueError(
            "unknown function_name {!r}; expected one of {}".format(
                function_name, supported))

    if random:
        x = np.random.rand(N) * xrange
        x = np.sort(x, axis=0)
    else:
        # Build the grid from integer indices. np.arange with a float step
        # may return N + 1 points because of rounding, which would break the
        # reshape and the noise addition below.
        x = np.arange(N) * (xrange / N)

    if function_name == 'sigmoid10':
        y = 1. / (1. + np.exp(-(x - xrange / 2.) * 10.))
    elif function_name == 'sq':
        y = x**2
    else:  # 'sqrt'
        y = np.sqrt(x)

    # The noise vector is always drawn, even for sigma == 0, so the state of
    # the NumPy random generator after the call does not depend on sigma.
    y = y + sigma*np.random.normal(0, 1., N)
    return x.reshape(N, 1), y

In [5]:
# Methods and target functions under evaluation. For a quicker run, narrow
# the lists (e.g. methods = ['smooth'], tasks = ['sigmoid10']) or lower T.
methods = ['monotonic_alt', 'smooth']
tasks = ['sq', 'sqrt', 'sigmoid10']
N_tasks, N_methods = len(tasks), len(methods)

# Shapes shared by the zero-initialised result buffers below.
shape_task_method_trial = (N_tasks, N_methods, T)
shape_task_trial = (N_tasks, T)

# Error / score buffers, indexed [task, method, trial].
MSE_train = np.zeros(shape_task_method_trial)
MSE_test = np.zeros(shape_task_method_trial)
MSE_clip = np.zeros(shape_task_method_trial)
R2_train = np.zeros(shape_task_method_trial)
R2_test = np.zeros(shape_task_method_trial)

# Data sets, indexed [task, trial, sample].
X_train = np.zeros(shape_task_trial + (N_train,))
Y_train = np.zeros(shape_task_trial + (N_train,))
X_test = np.zeros(shape_task_trial + (N_test,))
Y_test = np.zeros(shape_task_trial + (N_test,))

# Test outputs, indexed [task, method, trial, sample].
O_test = np.zeros(shape_task_method_trial + (N_test,))

# Parameter count per method.
no_params = np.zeros(N_methods)

# Neuron statistics, indexed [task, trial].
Active = np.zeros(shape_task_trial)
Dead = np.zeros(shape_task_trial)
ActiveInit = np.zeros(shape_task_trial)
active = 0

# Main experiment: for every trial and task, generate data, train each method
# and record training/test MSE together with the neuron statistics.
for trial in tnrange(T):
    for task_id, task in enumerate(tasks):
        # One distinct seed per (trial, task) pair, applied to the Python,
        # NumPy and PyTorch random generators.
        seed = task_id + trial*N_tasks
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # Training data: noisy targets at random inputs.
        # Test data: noise-free targets on an evenly spaced grid.
        x_train, y_train = generate1D(task, sigma=sigma, random=True, N=N_train)
        x_test, y_test   = generate1D(task, sigma=0., random=False, N=N_test)
        X_test[task_id, trial] = x_test.reshape(-1)
        Y_test[task_id, trial] = y_test
        X_train[task_id, trial] = x_train.reshape(-1)
        Y_train[task_id, trial] = y_train
        # float32 tensor copies of the data for the torch models.
        # NOTE(review): y_test_torch is not read anywhere within this loop.
        x_train_torch = torch.from_numpy(x_train.astype(np.float32)).clone()
        y_train_torch = torch.from_numpy(y_train.astype(np.float32)).clone()
        x_test_torch = torch.from_numpy(x_test.astype(np.float32)).clone()
        y_test_torch = torch.from_numpy(y_test.astype(np.float32)).clone()

        for method_id, method in enumerate(methods):
            match method:
                case 'smooth':
                    model = SmoothMonotonicNN(1, K, K, beta=2.)
                    # Record and report the parameter count only once, on the
                    # first (trial, task) combination.
                    if(trial+task_id==0):
                        no_params[method_id] = total_params(model)
                        print(method, total_params(model), "parameters")
                    fit_torch(model, x_train_torch, y_train_torch)
                    y_pred_train = model(x_train_torch).detach().numpy()
                    y_pred_test = model(x_test_torch).detach().numpy()
                    
                    # Backpropagate the summed test output from cleared
                    # gradients, then store what check_grad_neuron() returns
                    # as the dead-neuron count for this (task, trial).
                    # NOTE(review): presumably it counts neurons with zero
                    # gradient on the test inputs -- confirm in MonotonicNN.
                    model.zero_grad()
                    sum_y = torch.sum(model(x_test_torch))
                    sum_y.backward()
                    dead = model.check_grad_neuron()
                    Dead[task_id, trial] = dead
             
                case 'monotonic_alt':
                    model = MonotonicNNAlt(1, K, K)
                    # Forward pass of the untrained model on the test inputs,
                    # bracketed by reset_active_max() and active_max().
                    # NOTE(review): presumably the forward pass is what feeds
                    # the statistic returned by active_max() -- confirm in
                    # MonotonicNNAlt. The y_pred_test assigned here is
                    # overwritten after training.
                    model.reset_active_max()
                    y_pred_test = model(x_test_torch).detach().numpy()
                    activeInit, _ = model.active_max()
                    # Record and report the parameter count only once, on the
                    # first (trial, task) combination.
                    if(trial+task_id==0):
                        no_params[method_id] = total_params(model)
                        print(method, total_params(model), "parameters")
                    fit_torch(model, x_train_torch, y_train_torch)
                    y_pred_train = model(x_train_torch).detach().numpy()
                    # Same measurement again after training; only the first
                    # value returned by active_max() is kept.
                    model.reset_active_max()
                    y_pred_test = model(x_test_torch).detach().numpy()
                    active, _ = model.active_max()
                    
                    Active[task_id, trial] = active
                    ActiveInit[task_id, trial] = activeInit

            # Mean squared errors of the method that was just trained.
            MSE_train[task_id, method_id, trial] = mse(y_train, y_pred_train)
            MSE_test[task_id, method_id, trial] = mse(y_test, y_pred_test)
            


  0%|          | 0/21 [00:00<?, ?it/s]

monotonic_alt 73 parameters
smooth 73 parameters
tensor(0)
tensor(5)
tensor(2)
tensor(0)
tensor(0)
tensor(22)
tensor(3)
tensor(10)
tensor(2)
tensor(10)
tensor(0)
tensor(10)
tensor(5)
tensor(14)
tensor(4)
tensor(20)
tensor(0)
tensor(1)
tensor(1)
tensor(0)
tensor(0)
tensor(5)
tensor(0)
tensor(10)
tensor(2)
tensor(10)
tensor(5)
tensor(2)
tensor(0)
tensor(9)
tensor(1)
tensor(0)
tensor(4)
tensor(3)
tensor(0)
tensor(6)
tensor(2)
tensor(0)
tensor(0)
tensor(6)
tensor(7)
tensor(10)
tensor(6)
tensor(0)
tensor(12)
tensor(4)
tensor(0)
tensor(2)
tensor(7)
tensor(0)
tensor(6)
tensor(0)
tensor(0)
tensor(0)
tensor(17)
tensor(0)
tensor(18)
tensor(7)
tensor(0)
tensor(5)
tensor(2)
tensor(0)
tensor(1)


In [6]:
# LaTeX macros naming the three tasks, in the same order as the first axis of
# ActiveInit, Active and Dead.
functions = ("\\fsq", "\\fsqrt", "\\fsig")
n_neurons = K*K


def _active_cells(stats, f_id=None, digits=1):
    """Return the [min, mean, max] cells of stats for one row, or for the whole array."""
    if f_id is None:
        lo, avg, hi = stats.min(), np.mean(stats), stats.max()
    else:
        lo = stats.min(axis=1)[f_id]
        avg = np.mean(stats, axis=1)[f_id]
        hi = stats.max(axis=1)[f_id]
    return [str(int(lo)), "{:.{}f}".format(avg, digits), str(int(hi))]


def _alive_cells(dead, f_id=None, digits=1):
    """Return the [min, mean, max] cells of n_neurons minus the counts in dead."""
    if f_id is None:
        most, avg, fewest = dead.max(), np.mean(dead), dead.min()
    else:
        most = dead.max(axis=1)[f_id]
        avg = np.mean(dead, axis=1)[f_id]
        fewest = dead.min(axis=1)[f_id]
    # The largest dead count yields the smallest remaining count and vice versa.
    return [
        str(n_neurons - int(most)),
        "{:.{}f}".format(n_neurons - avg, digits),
        str(n_neurons - int(fewest)),
    ]


def _print_row(label, *cell_groups):
    """Print one LaTeX table row: cells joined by ' & ', terminated by ' \\\\'."""
    cells = [label]
    for group in cell_groups:
        cells.extend(group)
    print(" & ".join(cells), "\\\\")


# Table 1: ActiveInit and Active statistics per task, two decimals.
for f_id, f_name in enumerate(functions):
    _print_row(
        f_name,
        _active_cells(ActiveInit, f_id, digits=2),
        _active_cells(Active, f_id, digits=2),
    )

print()

# Table 2: n_neurons minus the Dead statistics per task, two decimals.
for f_id, f_name in enumerate(functions):
    _print_row(f_name, _alive_cells(Dead, f_id, digits=2))

print()

# Table 3: both tables combined with one decimal, plus a row over all tasks.
for f_id, f_name in enumerate(functions):
    _print_row(
        f_name,
        _active_cells(ActiveInit, f_id),
        _active_cells(Active, f_id),
        _alive_cells(Dead, f_id),
    )
_print_row(
    "overall",
    _active_cells(ActiveInit),
    _active_cells(Active),
    _alive_cells(Dead),
)

    
    

\fsq & 1 & 3.43 & 6 & 2 & 3.43 & 5 \\
\fsqrt & 2 & 3.86 & 7 & 1 & 2.14 & 4 \\
\fsig & 1 & 3.67 & 7 & 2 & 2.95 & 5 \\

\fsq & 16 & 31.10 & 36 \\
\fsqrt & 22 & 33.81 & 36 \\
\fsig & 14 & 29.86 & 36 \\

\fsq & 1 & 3.4 & 6 & 2 & 3.4 & 5 & 16 & 31.1 & 36 \\
\fsqrt & 2 & 3.9 & 7 & 1 & 2.1 & 4 & 22 & 33.8 & 36 \\
\fsig & 1 & 3.7 & 7 & 2 & 3.0 & 5 & 14 & 29.9 & 36 \\
overall & 1 & 3.7 & 7 & 1 & 2.8 & 5 & 14 & 31.6 & 36 \\
