In [10]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


from jax import random, nn
from jax import grad, value_and_grad
import jax
import jax.numpy as jnp

In [11]:
def neuron_initialization(architecture, seed = 42):
    
    params = {}
    
    key = random.PRNGKey(seed)
    for i in range(len(architecture)-1):
        
        inputs = architecture[i]
        oputs = architecture[i+1]
        
        initializer = jax.nn.initializers.he_normal()
        #initializer(subkey, (inputs, oputs), jnp.float32) 
        
        key, subkey = random.split(key)
        params[f'w_{i}'] = initializer(subkey, (inputs, oputs), jnp.float32)  #random.uniform(subkey, shape=(inputs, oputs), minval=-1, maxval= 1)     #Weights from neuron to neuron 
        key, subkey = random.split(key)
        params[f'b_{i}'] = initializer(subkey, (1, oputs), jnp.float32)  #random.uniform(subkey, shape=(1, oputs),  minval=-1, maxval= 1)           #Bais vecor for each layer
        
    return params

In [12]:
def forward_propagation(params, x_input):
    
    # could this be removed if we utilize back propagation using jax??
    a = x_input
    
    n_layers = int(len(params)/2)
    
    for i in range(n_layers):
        w = params[f'w_{i}']
        b = params[f'b_{i}']
        a_input = a
        
        z = a_input @ w + b
       
        if i < n_layers - 1:
            a = nn.relu(z)           #general simple activation function is used
        else: #problem specific case, for classification we do softmax
            a = nn.softmax(z)
                   
    return a

In [13]:
def cross_entropy_loss(params, x_input, y_labels, lamba_lasso = 0, lambda_ridge = 0):
    y_probs = forward_propagation(params, x_input)
    log_probs = jnp.log(y_probs) 
    one_hot_labels = nn.one_hot(y_labels, y_probs.shape[-1])  # Convert to one-hot encoding, and use the y_probes dims as the num of classes
    l = -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=1))
    
    if lamba_lasso != 0:
        l+= lamba_lasso*jnp.sum(jnp.array([jnp.sum(jnp.abs(params[f'w_{i}'])) for i in range(int(len(params) / 2))]))
    if lambda_ridge != 0:
        l+= lambda_ridge*jnp.sum(jnp.array([jnp.sum(params[f'w_{i}'] ** 2) for i in range(int(len(params) / 2))]))
          
    return l

In [14]:
def schuffel_data(x, y, seed):
    key = jax.random.PRNGKey(seed) 
    permutation = jax.random.permutation(key, x.shape[0])
    return x[permutation], y[permutation]

def batch_generator(x_input, y_target, batch_size, schuffel = True, seed = 42):
    
    n_batches = int(len(y_target)/batch_size)
        
    #x_batches = jnp.zeros(shape=(n_batches, batch_size, x_input.shape[1]))
    #y_batches = jnp.zeros(shape=(n_batches, batch_size))
    
    if schuffel:     
        x_input, y_target = schuffel_data(x_input, y_target, seed)
    
    for i in range(n_batches):
        
        #x_batches = x_batches.at[i].set(x_input[i*batch_size:(i+1)*batch_size,:])
        #y_batches = y_batches.at[i].set(y_target[i*batch_size:(i+1)*batch_size])

        x_batch = jnp.array(x_input[i*batch_size:(i+1)*batch_size,:])
        y_batch = jnp.array(y_target[i*batch_size:(i+1)*batch_size])
    
        yield x_batch, y_batch
    
    #return (x_batches, y_batches)

In [15]:
def gd_parameter_update(param_grad, params, alpha):
   
    #updating parameters
    updated_params = {}
    for param in params.keys():
        updated_params[param] = params[param] - alpha*param_grad[param]
    
    return updated_params

In [16]:
def standardize_data(x):

    
    mean = np.mean(x, axis=0)
    std = np.std(x, axis=0)
    x_std = (x - mean)/(std + 0.0001)

    
    return x_std, mean, std

def data_split(x, y, split_coeff = 0.8):
    n_rows = x.shape[0]
    
    sc_size = int(split_coeff * n_rows)
    
    
    sc_x = x[:sc_size]
    sc_y = y[:sc_size]
    
    sc_inv_x = x[sc_size:]
    sc_inv_y = y[sc_size:]
    
    return sc_x, sc_y, sc_inv_x, sc_inv_y

def prepare_data(x_input, y_target, standerdize = True):


    x_train, y_train, x_test, y_test = data_split(x_input, y_target, 0.8)


    x_train_n, y_train_n, x_train_v, y_train_v = data_split(x_train, y_train, 0.8)


    if standerdize:
        
        #main traing data
        x_std_train, mean_train, std_train = standardize_data(x_train)
        x_train  = x_std_train
        x_test = (x_test - mean_train)/std_train
                
        #Validation
        x_std_train_n, mean_train_n, std_train_n = standardize_data(x_train_n)
        x_train_n = x_std_train_n
        x_train_v = (x_train_v - mean_train_n)/std_train_n

    return x_train, y_train, x_train_n, y_train_n, x_train_v, y_train_v, x_test, y_test
  


In [17]:
def predict(params, x):
    y_probs = forward_propagation(params, x)
    return np.argmax(y_probs, axis=1)

def accuracy(params, x_input, y_target):
    y_pred = predict(params, x_input)
    return np.mean(y_target == y_pred)

In [18]:
def mnist_data_load(percent = 1):
    
    train_df = pd.read_csv('../data/mnist_train.csv')
    test_df = pd.read_csv('../data/mnist_test.csv')
    
    all_data = pd.concat([train_df, test_df], ignore_index=True)
    
    n_points = len(all_data)
    
    cut_off_val = int(percent * n_points)
    
    x_df = all_data.drop('label', axis=1)
    y_df = all_data['label']
    
    x_input = x_df.to_numpy()
    y_target = y_df.to_numpy()
     
    return x_input[:cut_off_val], y_target[:cut_off_val]

In [19]:
from optimizers import adam_parameter_update

def classification_train_2(params, loss, x_input, y_target, batch_size=25, epochs=200, alpha=0.01,
                         optimizer="gd", lambda_lasso=0, lambda_ridge=0, beta1=0.9, beta2=0.999, eps=1e-8):
    
    history = {'loss_v_epoch': []}
    
    if optimizer == "adam":
        m = {key: jnp.zeros_like(value) for key, value in params.items()}
        v = {key: jnp.zeros_like(value) for key, value in params.items()}
        t = 1
    
    for i in range(epochs): 
        for x_i_batch, y_i_batch in batch_generator(x_input, y_target, batch_size, schuffel=True, seed=42):
            
            loss_i, param_grad = value_and_grad(loss, argnums=0)(params, x_i_batch, y_i_batch, lambda_lasso, lambda_ridge)
            
            if optimizer == "gd":
                params = gd_parameter_update(param_grad, params, alpha)
            elif optimizer == "adam":
                params, m, v = adam_parameter_update(param_grad, params, m, v, t, alpha, beta1, beta2, eps)
                t += 1
            else:
                raise ValueError(f"Ottimizzatore '{optimizer}' non supportato!")
           
        history['loss_v_epoch'].append(loss_i)

        print(f'Epoch {i+1}/{epochs} -> Loss: {loss_i:.4f}')
            
    return params, history

In [47]:
import itertools
import jax.numpy as jnp
from sklearn.model_selection import KFold

def hyperparameter_search(params, architecture, data, labels, k_folds):
    alphas = [0.001]
    lasso_values = [0]
    ridge_values = [0.01] 
    batch_sizes = [256]
    epochs_values = [100, 200, 300, 320]

    best_acc = 0
    best_params = {}

    kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for alpha, lambda_lasso, lambda_ridge, batch_size, epochs in itertools.product(alphas, lasso_values, ridge_values, batch_sizes, epochs_values):
        print(f"\nTesting con: alpha={alpha}, Lasso={lambda_lasso}, Ridge={lambda_ridge}, batch_size={batch_size}, epochs={epochs}")

        fold_accuracies = []
        
        for train_index, val_index in kf.split(data):
            train_data, val_data = data[train_index], data[val_index]
            train_labels, val_labels = labels[train_index], labels[val_index]
            
            params = neuron_initialization(architecture)

            trained_params, history = classification_train_2(
                params, cross_entropy_loss, train_data, train_labels,
                epochs=epochs, batch_size=batch_size, optimizer="adam",
                alpha=alpha, lambda_lasso=lambda_lasso, lambda_ridge=lambda_ridge
            )

            val_acc = accuracy(trained_params, val_data, val_labels)
            fold_accuracies.append(val_acc)

        avg_acc = jnp.mean(jnp.array(fold_accuracies))
        print(f"Average Validation Accuracy: {avg_acc:.4f}")

        if avg_acc > best_acc:
            best_acc = avg_acc
            best_params = {
                'alpha': alpha,
                'lambda_lasso': lambda_lasso,
                'lambda_ridge': lambda_ridge,
                'batch_size': batch_size,
                'epochs': epochs
            }

    print("\nBest combination:")
    print(f"alpha={best_params['alpha']}, lambda_lasso={best_params['lambda_lasso']}, lambda_ridge={best_params['lambda_ridge']}, batch_size={best_params['batch_size']}, epochs={best_params['epochs']}")
    print(f"Best Average Accuracy: {best_acc:.4f}")

    return best_params

In [None]:
import preprocessing
import augmentation

dataset_path = "../data/mfeat-pix"

np.random.seed(42)

img_shape = (16, 15)

data = preprocessing.load_data(dataset_path)

num_classes = 10
samples_per_class = 200
train_samples_per_class = 100
test_samples_per_class = 100

train_data, test_data = preprocessing.split_data(data, num_classes, samples_per_class, train_samples_per_class)

def plot_sample_images_by_class(data, num_classes=10, samples_per_class=10, img_shape=(16, 15), save_path="../visualizations/dataset_visualization.pdf"):
    plt.figure(figsize=(10, 10))
    
    samples_per_class = min(samples_per_class, len(data) // num_classes)
    
    for class_idx in range(num_classes):
        start_idx = class_idx * (len(data) // num_classes)
        class_samples = data[start_idx : start_idx + samples_per_class]
        
        for i in range(samples_per_class):
            plt.subplot(num_classes, samples_per_class, class_idx * samples_per_class + i + 1)
            plt.imshow(class_samples[i].reshape(img_shape), cmap='gray')
            plt.axis('off')

    plt.savefig(save_path, format="pdf", dpi=300, bbox_inches='tight')

    plt.show()

plot_sample_images_by_class(train_data, num_classes=10, samples_per_class=10, img_shape=img_shape)

In [45]:
augment_rotate = np.radians(12)

train_data_aug = augmentation.augment_data(train_data, img_shape, augment_rotate, num_versions=3)

print(train_data_aug.shape)

(4000, 240)


In [48]:
train_labels, test_labels = preprocessing.create_labels(num_classes, train_samples_per_class=400, test_samples_per_class=100)
print(train_labels.shape)
print(test_labels.shape)


architecture = [train_data.shape[1], 64, 64, num_classes]

print(architecture)
params = neuron_initialization(architecture)

best_params = hyperparameter_search(params, architecture, train_data_aug, train_labels, k_folds=5)

(4000,)
(1000,)
[240, 64, 64, 10]

Testing con: alpha=0.001, Lasso=0, Ridge=0.01, batch_size=256, epochs=100
Epoch 1/100 -> Loss: 4.4547
Epoch 2/100 -> Loss: 3.5420
Epoch 3/100 -> Loss: 3.0791
Epoch 4/100 -> Loss: 2.8409
Epoch 5/100 -> Loss: 2.6372
Epoch 6/100 -> Loss: 2.4851
Epoch 7/100 -> Loss: 2.3543
Epoch 8/100 -> Loss: 2.2446
Epoch 9/100 -> Loss: 2.1447
Epoch 10/100 -> Loss: 2.0558
Epoch 11/100 -> Loss: 1.9712
Epoch 12/100 -> Loss: 1.8932
Epoch 13/100 -> Loss: 1.8215
Epoch 14/100 -> Loss: 1.7541
Epoch 15/100 -> Loss: 1.6907
Epoch 16/100 -> Loss: 1.6316
Epoch 17/100 -> Loss: 1.5750
Epoch 18/100 -> Loss: 1.5225
Epoch 19/100 -> Loss: 1.4711
Epoch 20/100 -> Loss: 1.4245
Epoch 21/100 -> Loss: 1.3794
Epoch 22/100 -> Loss: 1.3370
Epoch 23/100 -> Loss: 1.2960
Epoch 24/100 -> Loss: 1.2563
Epoch 25/100 -> Loss: 1.2192
Epoch 26/100 -> Loss: 1.1839
Epoch 27/100 -> Loss: 1.1498
Epoch 28/100 -> Loss: 1.1172
Epoch 29/100 -> Loss: 1.0856
Epoch 30/100 -> Loss: 1.0557
Epoch 31/100 -> Loss: 1.0263
E

In [49]:
architecture = [train_data.shape[1], 64, 64, num_classes]

print(architecture)
params = neuron_initialization(architecture)

trained_params, history = classification_train_2(
                params, cross_entropy_loss, train_data_aug, train_labels,
                epochs=300, batch_size=256, optimizer="adam",
                alpha=0.001, lambda_ridge=0.01
            )

accuracy(trained_params, test_data, test_labels)

[240, 64, 64, 10]
Epoch 1/300 -> Loss: 4.1158
Epoch 2/300 -> Loss: 3.2390
Epoch 3/300 -> Loss: 2.8998
Epoch 4/300 -> Loss: 2.6671
Epoch 5/300 -> Loss: 2.4718
Epoch 6/300 -> Loss: 2.3129
Epoch 7/300 -> Loss: 2.1852
Epoch 8/300 -> Loss: 2.0728
Epoch 9/300 -> Loss: 1.9702
Epoch 10/300 -> Loss: 1.8753
Epoch 11/300 -> Loss: 1.7875
Epoch 12/300 -> Loss: 1.7073
Epoch 13/300 -> Loss: 1.6325
Epoch 14/300 -> Loss: 1.5627
Epoch 15/300 -> Loss: 1.4980
Epoch 16/300 -> Loss: 1.4376
Epoch 17/300 -> Loss: 1.3801
Epoch 18/300 -> Loss: 1.3261
Epoch 19/300 -> Loss: 1.2760
Epoch 20/300 -> Loss: 1.2281
Epoch 21/300 -> Loss: 1.1830
Epoch 22/300 -> Loss: 1.1409
Epoch 23/300 -> Loss: 1.0998
Epoch 24/300 -> Loss: 1.0610
Epoch 25/300 -> Loss: 1.0247
Epoch 26/300 -> Loss: 0.9901
Epoch 27/300 -> Loss: 0.9574
Epoch 28/300 -> Loss: 0.9260
Epoch 29/300 -> Loss: 0.8976
Epoch 30/300 -> Loss: 0.8686
Epoch 31/300 -> Loss: 0.8418
Epoch 32/300 -> Loss: 0.8163
Epoch 33/300 -> Loss: 0.7927
Epoch 34/300 -> Loss: 0.7692
Epoch

Array(0.975, dtype=float32)