In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


from jax import random, nn
from jax import grad, value_and_grad
import jax
import jax.numpy as jnp

In [3]:
def neuron_initialization(architecture, seed = 42):
    
    params = {}
    
    key = random.PRNGKey(seed)
    for i in range(len(architecture)-1):
        
        inputs = architecture[i]
        oputs = architecture[i+1]
        
        initializer = jax.nn.initializers.he_normal()
        #initializer(subkey, (inputs, oputs), jnp.float32) 
        
        key, subkey = random.split(key)
        params[f'w_{i}'] = initializer(subkey, (inputs, oputs), jnp.float32)  #random.uniform(subkey, shape=(inputs, oputs), minval=-1, maxval= 1)     #Weights from neuron to neuron 
        key, subkey = random.split(key)
        params[f'b_{i}'] = initializer(subkey, (1, oputs), jnp.float32)  #random.uniform(subkey, shape=(1, oputs),  minval=-1, maxval= 1)           #Bais vecor for each layer
        
    return params 
    
architecture = [3,4,1,2,6,10]
test_params = neuron_initialization(architecture)
test_params


{'w_0': Array([[ 0.53359914,  0.7004871 , -0.79396784, -0.55923206],
        [-1.0524203 , -0.72886986, -0.4186612 , -1.0745062 ],
        [-0.15655585, -0.04356752, -0.36388826, -0.34793305]],      dtype=float32),
 'b_0': Array([[-0.32341576, -2.0128593 , -0.06906059, -1.7266749 ]], dtype=float32),
 'w_1': Array([[-0.30228475],
        [-0.32593513],
        [-0.8055143 ],
        [-1.1263516 ]], dtype=float32),
 'b_1': Array([[1.0258774]], dtype=float32),
 'w_2': Array([[ 1.8166499, -1.2137247]], dtype=float32),
 'b_2': Array([[ 1.238835 , -1.5672728]], dtype=float32),
 'w_3': Array([[-1.0958    , -0.34511802,  1.9544435 ,  0.6149077 ,  0.29602355,
         -0.5394382 ],
        [ 0.14415586,  1.1483663 ,  1.1170226 , -0.8584173 ,  0.61064976,
          1.0186568 ]], dtype=float32),
 'b_3': Array([[-0.7017356 , -2.5614824 ,  0.14467148, -1.7404512 ,  0.8899914 ,
          1.2486326 ]], dtype=float32),
 'w_4': Array([[ 0.38829002, -0.33662477,  0.08499774, -0.940721  , -0.13233559,
  

In [4]:
def forward_propagation(params, x_input):
    
    # could this be removed if we utilize back propagation using jax??
    a = x_input
    
    n_layers = int(len(params)/2)
    
    for i in range(n_layers):
        w = params[f'w_{i}']
        b = params[f'b_{i}']
        a_input = a
        
        z = a_input @ w + b
       
        if i < n_layers - 1:
            a = nn.relu(z)           #general simple activation function is used
        else: #problem specific case, for classification we do softmax
            a = nn.softmax(z)
            
        
    return a


architecture = [2,4,1,2,6,2]
test_params = neuron_initialization(architecture)
test_input = jnp.array([1,1])

a = forward_propagation(test_params, test_input)
a

Array([[0.8991749 , 0.10082512]], dtype=float32)

In [5]:
def cross_entropy_loss(params, x_input, y_labels, lamba_lasso = 0, lambda_ridge = 0):
    y_probs = forward_propagation(params, x_input)
    log_probs = jnp.log(y_probs) 
    one_hot_labels = nn.one_hot(y_labels, y_probs.shape[-1])  # Convert to one-hot encoding, and use the y_probes dims as the num of classes
    l = -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=1))
    
    if lamba_lasso != 0:
        l+= lamba_lasso*jnp.sum(jnp.array([jnp.sum(jnp.abs(params[f'w_{i}'])) for i in range(int(len(params) / 2))]))
    if lambda_ridge != 0:
        l+= lambda_ridge*jnp.sum(jnp.array([jnp.sum(params[f'w_{i}'] ** 2) for i in range(int(len(params) / 2))]))
        
    
    return l

architecture = [2,4,1,2,6,2]
test_params = neuron_initialization(architecture)
test_input = jnp.array([[1,1], [2,2]])
test_labesl = jnp.array([0,1])


print(cross_entropy_loss(test_params, test_input, test_labesl))

grad(cross_entropy_loss, argnums=0)(test_params, test_input, test_labesl)

loss_value, Wb_grad = value_and_grad(cross_entropy_loss, argnums=0)(test_params, test_input, test_labesl)
print('loss value', loss_value)
print('grad loss value', Wb_grad)

1.2003227
loss value 1.2003227
grad loss value {'b_0': Array([[0., 0., 0., 0.]], dtype=float32), 'b_1': Array([[0.0897226]], dtype=float32), 'b_2': Array([[0.04938904, 0.        ]], dtype=float32), 'b_3': Array([[ 0.        ,  0.        , -0.12927216,  0.36489886,  0.26236102,
         0.        ]], dtype=float32), 'b_4': Array([[ 0.39917496, -0.3991749 ]], dtype=float32), 'w_0': Array([[0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32), 'w_1': Array([[0.],
       [0.],
       [0.],
       [0.]], dtype=float32), 'w_2': Array([[0.0506671, 0.       ]], dtype=float32), 'w_3': Array([[ 0.       ,  0.       , -0.4010662,  1.1320969,  0.8139738,
         0.       ],
       [ 0.       ,  0.       ,  0.       ,  0.       ,  0.       ,
         0.       ]], dtype=float32), 'w_4': Array([[ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 2.4782069 , -2.4782066 ],
       [ 0.06678068, -0.06678067],
       [ 0.7218692 , -0.7218691 ],
       [ 0.        ,  0.        ]],

In [6]:
def schuffel_data(x, y, seed):
    key = jax.random.PRNGKey(seed) 
    permutation = jax.random.permutation(key, x.shape[0])
    return x[permutation], y[permutation]

def batch_generator(x_input, y_target, batch_size, schuffel = True, seed = 42):
    
    n_batches = int(len(y_target)/batch_size)
        
    #x_batches = jnp.zeros(shape=(n_batches, batch_size, x_input.shape[1]))
    #y_batches = jnp.zeros(shape=(n_batches, batch_size))
    
    if schuffel:     
        x_input, y_target = schuffel_data(x_input, y_target, seed)
    
    for i in range(n_batches):
        
        #x_batches = x_batches.at[i].set(x_input[i*batch_size:(i+1)*batch_size,:])
        #y_batches = y_batches.at[i].set(y_target[i*batch_size:(i+1)*batch_size])

        x_batch = jnp.array(x_input[i*batch_size:(i+1)*batch_size,:])
        y_batch = jnp.array(y_target[i*batch_size:(i+1)*batch_size])
    
        yield x_batch, y_batch
    
    #return (x_batches, y_batches)

x_input = jnp.array([[1, 10], [78, 5], [3, 7], [8, 1]]) 
y_target = jnp.array([1, 0, 1, 0])
batch_size = 2

for x_ij, y_ij in batch_generator(x_input, y_target, batch_size):
    print(x_ij)
    print(y_ij)
    print('-----')

[[3 7]
 [8 1]]
[1 0]
-----
[[ 1 10]
 [78  5]]
[1 0]
-----


In [7]:
def gd_parameter_update(param_grad, params, alpha):
   
    #updating parameters
    updated_params = {}
    for param in params.keys():
        updated_params[param] = params[param] - alpha*param_grad[param]
    
    return updated_params

In [8]:
def standardize_data(x):

    
    mean = np.mean(x, axis=0)
    std = np.std(x, axis=0)
    x_std = (x - mean)/(std + 0.0001)

    
    return x_std, mean, std

def data_split(x, y, split_coeff = 0.8):
    n_rows = x.shape[0]
    
    sc_size = int(split_coeff * n_rows)
    
    
    sc_x = x[:sc_size]
    sc_y = y[:sc_size]
    
    sc_inv_x = x[sc_size:]
    sc_inv_y = y[sc_size:]
    
    return sc_x, sc_y, sc_inv_x, sc_inv_y

def prepare_data(x_input, y_target, standerdize = True):


    x_train, y_train, x_test, y_test = data_split(x_input, y_target, 0.8)


    x_train_n, y_train_n, x_train_v, y_train_v = data_split(x_train, y_train, 0.8)


    if standerdize:
        
        #main traing data
        x_std_train, mean_train, std_train = standardize_data(x_train)
        x_train  = x_std_train
        x_test = (x_test - mean_train)/std_train
                
        #Validation
        x_std_train_n, mean_train_n, std_train_n = standardize_data(x_train_n)
        x_train_n = x_std_train_n
        x_train_v = (x_train_v - mean_train_n)/std_train_n

    return x_train, y_train, x_train_n, y_train_n, x_train_v, y_train_v, x_test, y_test
  


In [9]:
def predict(params, x):
    y_probs = forward_propagation(params, x)
    return np.argmax(y_probs, axis=1)

def accuracy(params, x_input, y_target):
    y_pred = predict(params, x_input)
    return np.mean(y_target == y_pred)

In [10]:
def mnist_data_load(percent = 1):
    

    
    train_df = pd.read_csv('../data/mnist_train.csv')
    test_df = pd.read_csv('../data/mnist_test.csv')
    
    all_data = pd.concat([train_df, test_df], ignore_index=True)
    
    n_points = len(all_data)
    
    cut_off_val = int(percent * n_points)
    
    x_df = all_data.drop('label', axis=1)
    y_df = all_data['label']
    
    x_input = x_df.to_numpy()
    y_target = y_df.to_numpy()
    
    
    
    return x_input[:cut_off_val], y_target[:cut_off_val]


In [11]:
def classification_train(params, loss, x_input, y_target, batch_size = 25, epochs = 200, alpha = 0.01):
    history = {'loss_v_epoch': [], 'accuracy_v_epoch': []}
    for i in range(epochs): 
        j = 0
        for x_i_batch, y_i_batch in batch_generator(x_input, y_target, batch_size, schuffel = True, seed = 42): #we go over each batch
            loss_i, param_grad = value_and_grad(loss, argnums=0)(params, x_i_batch, y_i_batch)
            params = gd_parameter_update(param_grad, params, alpha)
           

        history['loss_v_epoch'].append(loss_i)

        print(f'Epoch {i} -> loss: {loss_i}')

            
    return params,  history


def classification_task(architecture, loss_func, x_input, y_tartget, batch_size = 50, epochs = 10, alpha = 0.01):
    #x_input is the numpy array, which then gets convereted to jax in the batching porcces
    params = neuron_initialization(architecture)
    
    #data_proccesing
    x_train, y_train, x_train_n, y_train_n, X_train_v, y_train_v, x_test, y_test = prepare_data
    
    #loss construction
    
    
    
    #hyperparameter selection
    
    
    #Training on full set with choosen hyperparameters
    
    params = classification_train(params, loss, x_train, y_train)
    
    

    
    
    #loggic/plotting info
    
    
    pass

In [None]:
x_input, y_target = mnist_data_load(0.1)
x_input = x_input/255


In [None]:
input_dim = x_input.shape[1]
output_dim = 10 
architecture = [input_dim,  1000, 1000, output_dim]


params = neuron_initialization(architecture)


x_train, y_train, x_test, y_test = data_split(x_input, y_target, split_coeff = 0.8)
#x_train, mean, std = standardize_data(x_train)
#print(std)
#x_test = (x_test - mean)/(std + 0.0001)


param_trained, history = classification_train(params, cross_entropy_loss, x_train, y_train)

accuracy(param_trained, x_test, y_test)

In [None]:
plt.plot(np.linspace(1,200,200), history['loss_v_epoch'])
plt.show()

In [23]:
from optimizers import adam_parameter_update

def classification_train_2(params, loss, x_input, y_target, batch_size=25, epochs=200, alpha=0.01,
                         optimizer="gd", lambda_lasso=0, lambda_ridge=0, beta1=0.9, beta2=0.999, eps=1e-8):
    
    history = {'loss_v_epoch': []}
    
    if optimizer == "adam":
        m = {key: jnp.zeros_like(value) for key, value in params.items()}
        v = {key: jnp.zeros_like(value) for key, value in params.items()}
        t = 1
    
    for i in range(epochs): 
        for x_i_batch, y_i_batch in batch_generator(x_input, y_target, batch_size, schuffel=True, seed=42):
            
            loss_i, param_grad = value_and_grad(loss, argnums=0)(params, x_i_batch, y_i_batch, lambda_lasso, lambda_ridge)
            
            if optimizer == "gd":
                params = gd_parameter_update(param_grad, params, alpha)
            elif optimizer == "adam":
                params, m, v = adam_parameter_update(param_grad, params, m, v, t, alpha, beta1, beta2, eps)
                t += 1
            else:
                raise ValueError(f"Ottimizzatore '{optimizer}' non supportato!")
           
        history['loss_v_epoch'].append(loss_i)

        print(f'Epoch {i+1}/{epochs} -> Loss: {loss_i:.4f}')
            
    return params, history


In [28]:
import itertools
import jax.numpy as jnp

def hyperparameter_search(params, train_data, train_labels, test_data, test_labels):
    alphas = [0.001, 0.0005, 0.0001]
    lasso_values = [0.01, 0.001, 0.0001]
    ridge_values = [0.1, 0.01, 0.001] 
    batch_sizes = [50, 100]

    best_acc = 0
    best_params = {}

    for alpha, lambda_lasso, lambda_ridge, batch_size in itertools.product(alphas, lasso_values, ridge_values, batch_sizes):
        print(f"\nTesting con: alpha={alpha}, Lasso={lambda_lasso}, Ridge={lambda_ridge}, batch_size={batch_size}")

        params = neuron_initialization(architecture)

        trained_params, history = classification_train_2(
            params, cross_entropy_loss, train_data, train_labels,
            epochs=500, batch_size=batch_size, optimizer="adam",
            alpha=alpha, lambda_lasso=lambda_lasso, lambda_ridge=lambda_ridge
        )

        test_acc = accuracy(trained_params, test_data, test_labels)
        print(f"Test Accuracy: {test_acc:.4f}")

        if test_acc > best_acc:
            best_acc = test_acc
            best_params = {'alpha': alpha, 'lambda_lasso': lambda_lasso, 'lambda_ridge': lambda_ridge, 'batch_size': batch_size}

    print("\nMiglior combinazione trovata:")
    print(f"alpha={best_params['alpha']}, lambda_lasso={best_params['lambda_lasso']}, lambda_ridge={best_params['lambda_ridge']}, batch_size={best_params['batch_size']}")
    print(f"Best Accuracy: {best_acc:.4f}")

    return best_params


In [29]:
import preprocessing
dataset_path = "../data/mfeat-pix"

np.random.seed(42)



img_shape = (16, 15)

data = preprocessing.load_data(dataset_path)

num_classes = 10
samples_per_class = 200
train_samples_per_class = 100
test_samples_per_class = 100

train_data, test_data = preprocessing.split_data(data, num_classes, samples_per_class, train_samples_per_class)


train_labels, test_labels = preprocessing.create_labels(num_classes, train_samples_per_class, test_samples_per_class)



architecture = [train_data.shape[1], 64, num_classes]

print(architecture)
params = neuron_initialization(architecture)

best_params = hyperparameter_search(params, train_data, train_labels, test_data, test_labels)

accuracy(param_trained, test_data, test_labels)

[240, 64, 10]

Testing con: alpha=0.001, Lasso=0.01, Ridge=0.1, batch_size=50
Epoch 1/500 -> Loss: 26.4251
Epoch 2/500 -> Loss: 22.4840
Epoch 3/500 -> Loss: 19.3641
Epoch 4/500 -> Loss: 16.6284
Epoch 5/500 -> Loss: 14.2622
Epoch 6/500 -> Loss: 12.2299
Epoch 7/500 -> Loss: 10.5038
Epoch 8/500 -> Loss: 9.0317
Epoch 9/500 -> Loss: 7.7801
Epoch 10/500 -> Loss: 6.7184
Epoch 11/500 -> Loss: 5.8274
Epoch 12/500 -> Loss: 5.0950
Epoch 13/500 -> Loss: 4.4894
Epoch 14/500 -> Loss: 3.9851
Epoch 15/500 -> Loss: 3.5588
Epoch 16/500 -> Loss: 3.2074
Epoch 17/500 -> Loss: 2.9146
Epoch 18/500 -> Loss: 2.6696
Epoch 19/500 -> Loss: 2.4688
Epoch 20/500 -> Loss: 2.2985
Epoch 21/500 -> Loss: 2.1594
Epoch 22/500 -> Loss: 2.0459
Epoch 23/500 -> Loss: 1.9536
Epoch 24/500 -> Loss: 1.8756
Epoch 25/500 -> Loss: 1.8101
Epoch 26/500 -> Loss: 1.7535
Epoch 27/500 -> Loss: 1.7055
Epoch 28/500 -> Loss: 1.6646
Epoch 29/500 -> Loss: 1.6305
Epoch 30/500 -> Loss: 1.6000
Epoch 31/500 -> Loss: 1.5759
Epoch 32/500 -> Loss: 1.5

Array(0.95400006, dtype=float32)