### Train a neural network to predict scores

In [None]:
import os
from functools import reduce

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import grad, jit, pmap, vmap
from jax.scipy.special import logsumexp
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST


### Initialize weights and biases in NNet

In [None]:
def init_MLP(layer_widths, parent_key, scale=0.01):
    """Create randomly initialised weights and biases for a dense MLP.

    Args:
        layer_widths: widths of consecutive layers, input first, output last.
        parent_key: JAX PRNG key; one sub-key is split off per layer and then
            split again into a weight key and a bias key.
        scale: multiplier applied to the standard-normal samples.

    Returns:
        A list with one ``[weight, bias]`` list per layer, where ``weight``
        has shape ``(out_width, in_width)`` and ``bias`` has shape
        ``(out_width,)``.
    """
    layer_keys = jax.random.split(parent_key, num=len(layer_widths) - 1)
    layer_shapes = zip(layer_widths[:-1], layer_widths[1:])

    def _layer(shape, key):
        n_in, n_out = shape
        w_key, b_key = jax.random.split(key)
        return [
            scale * jax.random.normal(w_key, shape=(n_out, n_in)),
            scale * jax.random.normal(b_key, shape=(n_out,)),
        ]

    return [_layer(shape, key) for shape, key in zip(layer_shapes, layer_keys)]

### Return predictions from NNet

In [None]:
def predict_tanh(params, x):
    """Run one (unbatched) input through the MLP using tanh hidden layers.

    Args:
        params: sequence of ``(weight, bias)`` pairs, first layer first, as
            built by ``init_MLP``; each ``weight`` is ``(out, in)``.
        x: a single input vector.

    Returns:
        The output of the final layer, which is affine (no activation).
    """
    out_w, out_b = params[-1]
    h = x
    for layer_w, layer_b in params[:-1]:
        h = jax.nn.tanh(jnp.dot(layer_w, h) + layer_b)
    return jnp.dot(out_w, h) + out_b

def predict_relu(params, x):
    """Run one (unbatched) input through the MLP using ReLU hidden layers.

    Args:
        params: sequence of ``(weight, bias)`` pairs, first layer first, as
            built by ``init_MLP``; each ``weight`` is ``(out, in)``.
        x: a single input vector.

    Returns:
        The output of the final layer, which is affine (no activation).
    """
    def _dense_relu(h, layer):
        w, b = layer
        return jax.nn.relu(jnp.dot(w, h) + b)

    hidden_out = reduce(_dense_relu, params[:-1], x)
    w_last, b_last = params[-1]
    return jnp.dot(w_last, hidden_out) + b_last

# Batched forward passes: vmap maps over the leading axis of x while the same
# params (in_axes=None) are shared by every sample.
batched_tanh_predict = vmap(predict_tanh, in_axes=(None, 0))
batched_relu_predict = vmap(predict_relu, in_axes=(None, 0))


### Transform functions that convert samples to arrays and standardize the target variable

In [None]:
def cus_transform(x):
    """Return ``x`` as a new, flattened (1-D) float64 NumPy array."""
    return np.array(x, dtype=np.float64).ravel()

def cus_collate(batch):
    """Collate a list of ``(features, target)`` samples for a DataLoader.

    Returns:
        ``(features, targets)``: each field of the samples stacked into its
        own NumPy array, with the batch along the leading axis.
    """
    fields = tuple(zip(*batch))
    stacked_targets = np.array(fields[1])
    return np.array(fields[0]), stacked_targets

def cus_target_transform(y, mean=None, std=None):
    """Standardize ``y`` as ``(y - mean) / std``.

    Args:
        y: array of target values.
        mean: centre to subtract; when None it is fitted as ``np.mean(y)``.
        std: scale to divide by; when None it is fitted as ``np.std(y)``
            (NumPy's default, ddof=0).

    Returns:
        ``(scaled_y, mean, std)`` so statistics fitted on a training split can
        be passed back in to scale another split identically.
    """
    # Explicit None checks: a supplied mean of exactly 0 is a valid statistic
    # and must not be mistaken for "not given" (truthiness would re-fit it).
    if mean is None:
        mean = np.mean(y)
    if std is None:
        std = np.std(y)
    return (y - mean) / std, mean, std

### Define a custom dataset for simple batching

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset pairing a feature CSV with its matching target CSV.

    Both files are read from ``dir`` with ``pandas.read_csv`` and stored as
    NumPy arrays. The target file name is ``csv_file`` with every ``'X'``
    replaced by ``'y'`` (e.g. ``sel_X_train_0.csv`` -> ``sel_y_train_0.csv``).

    Args:
        dir: directory containing both CSV files.
        csv_file: file name of the feature CSV.
        transform: optional callable applied to each feature row in
            ``__getitem__``.
        target_transform: optional callable ``f(y, mean=None, std=None)``
            returning ``(scaled_y, mean, std)``; applied once to the whole
            target array during construction.
        test_transform: optional ``(mean, std)`` pair (normally a training
            dataset's ``scaler``) forwarded to ``target_transform`` so the
            targets are scaled with externally supplied statistics instead of
            statistics fitted on this split. Requires ``target_transform``.

    Attributes:
        scaler: ``(mean, std)`` used to scale the targets, or ``None`` when
            no target scaling was applied.

    Raises:
        ValueError: if ``test_transform`` is given without ``target_transform``.
    """

    def __init__(self, dir, csv_file, transform=None, target_transform=None, test_transform=None):
        self.features = pd.read_csv(os.path.join(dir, csv_file)).to_numpy()
        # Note: str.replace swaps every 'X' in the name, not only the first.
        target_file = csv_file.replace('X', 'y')
        self.target = pd.read_csv(os.path.join(dir, target_file)).to_numpy()
        self.transform = transform
        self.target_transform = target_transform
        self.test_transform = test_transform
        self.scaler = None

        if self.test_transform is not None:
            # Scale with the supplied statistics so that, e.g., a test split
            # shares the training split's scale.
            if self.target_transform is None:
                raise ValueError('test_transform requires a target_transform')
            target, mean, std = self.target_transform(
                self.target, mean=self.test_transform[0], std=self.test_transform[1])
            self.target = target
            self.scaler = (mean, std)
        elif self.target_transform is not None:
            # Fit the scaling statistics on this split's own targets.
            target, mean, std = self.target_transform(self.target)
            self.target = target
            self.scaler = (mean, std)

    def __len__(self):
        """Return the number of target rows."""
        return len(self.target)

    def __getitem__(self, idx):
        """Return ``(features, target)`` for row ``idx``.

        ``transform``, when set, is applied to the feature row only.
        """
        features = self.features[idx, :]
        target = self.target[idx]
        if self.transform:
            features = self.transform(features)
        return features, target

In [None]:
# Sanity check on split 0: build its train/test datasets (test targets are
# scaled with the training split's statistics) and print the type of one
# target together with the number of rows in each split.
i = 0
scalers_sel = {}
train_dataset = CustomDataset(
    './splits',
    'sel_X_train_{}.csv'.format(i),
    target_transform=cus_target_transform,
)
scalers_sel['sel_X_{}'.format(i)] = train_dataset.scaler
test_dataset = CustomDataset(
    './splits',
    'sel_X_test_{}.csv'.format(i),
    target_transform=cus_target_transform,
    test_transform=scalers_sel['sel_X_{}'.format(i)],
)
print(type(train_dataset[0][1]), len(train_dataset))
print(type(test_dataset[0][1]), len(test_dataset))


### Define loss functions and update rules - RMSE/MAE with stochastic gradient descent

In [None]:
def root_mean_squared_error(predictions, targs):
    """Return ``sqrt(mean((predictions - targs) ** 2))`` over all elements."""
    # The square root must wrap the mean. Applied per element, sqrt(r**2) is
    # just |r| (i.e. MAE), and its gradient is NaN (0 * inf) wherever a
    # residual is exactly zero.
    return jnp.sqrt(jnp.mean((predictions - targs) ** 2))


def mean_absolute_error(predictions, targs):
    """Return ``mean(|predictions - targs|)`` over all elements."""
    return jnp.mean(jnp.abs(predictions - targs))


def RMSE_tanh(params, feats, targs):
    """RMSE of the tanh network (``batched_tanh_predict``) on one batch."""
    return root_mean_squared_error(batched_tanh_predict(params, feats), targs)


def RMSE_relu(params, feats, targs):
    """RMSE of the ReLU network (``batched_relu_predict``) on one batch."""
    return root_mean_squared_error(batched_relu_predict(params, feats), targs)


def MAE_tanh(params, feats, targs):
    """MAE of the tanh network (``batched_tanh_predict``) on one batch."""
    return mean_absolute_error(batched_tanh_predict(params, feats), targs)


def MAE_relu(params, feats, targs):
    """MAE of the ReLU network (``batched_relu_predict``) on one batch."""
    return mean_absolute_error(batched_relu_predict(params, feats), targs)


loss_functions = {'RMSE_tanh': RMSE_tanh, 'RMSE_relu': RMSE_relu, 'MAE_tanh': MAE_tanh, 'MAE_relu': MAE_relu}

def sgd_step(params, grads, lr):
    """Return ``params`` moved one plain gradient-descent step: ``p - lr * g``.

    Applied leaf-wise to matching pytrees (here the list of ``[w, b]`` pairs
    built by ``init_MLP``).
    """
    # jax.tree_util.tree_map is the stable API; the top-level jax.tree_map
    # alias is deprecated and has been removed in newer JAX releases.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


@jit
def update_MAE_relu(params, feats, targs, lr=0.05):
    """One SGD step on ``MAE_relu`` for a single batch."""
    return sgd_step(params, grad(MAE_relu)(params, feats, targs), lr)


@jit
def update_MAE_tanh(params, feats, targs, lr=0.05):
    """One SGD step on ``MAE_tanh`` for a single batch."""
    return sgd_step(params, grad(MAE_tanh)(params, feats, targs), lr)


@jit
def update_RMSE_relu(params, feats, targs, lr=0.05):
    """One SGD step on ``RMSE_relu`` for a single batch."""
    return sgd_step(params, grad(RMSE_relu)(params, feats, targs), lr)


@jit
def update_RMSE_tanh(params, feats, targs, lr=0.05):
    """One SGD step on ``RMSE_tanh`` for a single batch."""
    return sgd_step(params, grad(RMSE_tanh)(params, feats, targs), lr)

### Run the neural network

In [None]:
seed = 0
scalers_sel = {}


def evaluate_test_set(predict, params, t_feats, t_targs, use_rmse):
    """Predict every test sample with ``params`` and score the predictions.

    Args:
        predict: single-sample forward function ``predict(params, x)`` such as
            ``predict_tanh`` or ``predict_relu``; it is vmapped over samples.
        params: network parameters as returned by ``init_MLP``.
        t_feats: sequence of per-sample feature arrays.
        t_targs: sequence of per-sample target arrays aligned with t_feats.
        use_rmse: score with root-mean-squared error when True, otherwise
            with mean absolute error.

    Returns:
        ``(predictions, loss)``: a NumPy array stacking the per-sample outputs
        and the loss as a Python float.
    """
    batched = vmap(predict, in_axes=(None, 0))
    predictions = np.asarray(batched(params, jnp.stack(t_feats)))
    # Reshape targets to the prediction shape so a shape mismatch raises
    # instead of silently broadcasting (N, 1) against (N,) into (N, N).
    errors = predictions - np.asarray(t_targs).reshape(predictions.shape)
    if use_rmse:
        loss = np.sqrt(np.mean(errors ** 2))
    else:
        loss = np.mean(np.abs(errors))
    return predictions, float(loss)


def run_sel(i, hl, num_epochs, update, predict):
    """Train an MLP on split ``i`` of the 'sel' data and evaluate it.

    Reads ``./splits/sel_X_train_{i}.csv`` and ``./splits/sel_X_test_{i}.csv``
    via ``CustomDataset``, standardizes targets with the training statistics
    (stored in the module-level ``scalers_sel`` under ``'sel_X_{i}'``), and
    trains with ``update`` on mini-batches of 128 (a final partial batch is
    dropped). The test loss is printed after the first epoch and after every
    10th epoch.

    Args:
        i: split index used in the file names.
        hl: list of hidden-layer widths.
        num_epochs: number of passes over the training data.
        update: one of the ``update_*`` functions; the test metric is RMSE
            when its name contains 'rmse' (case-insensitive), MAE otherwise.
        predict: single-sample forward function matching ``update``'s
            activation (``predict_tanh`` or ``predict_relu``).

    Returns:
        ``(predictions, t_targs, scaler)``: test-set predictions of the final
        parameters as a NumPy array, the tuple of per-sample (standardized)
        test targets, and the training ``(mean, std)`` scaler.
    """
    split_key = 'sel_X_{}'.format(i)
    train_dataset = CustomDataset('./splits', 'sel_X_train_{}.csv'.format(i),
                                  target_transform=cus_target_transform)
    scalers_sel[split_key] = train_dataset.scaler
    test_dataset = CustomDataset('./splits', 'sel_X_test_{}.csv'.format(i),
                                 target_transform=cus_target_transform,
                                 test_transform=scalers_sel[split_key])

    # NOTE(review): shuffle is not passed, so DataLoader's default applies and
    # batches may arrive in the same order each epoch - confirm intended.
    train_loader = DataLoader(train_dataset, batch_size=128,
                              collate_fn=cus_collate, drop_last=True)
    # Input width from one sample rather than from a whole collated batch.
    first_layer = np.shape(train_dataset[0][0])[-1]

    key = jax.random.PRNGKey(seed)
    MLP_params = init_MLP([first_layer] + hl + [1], key)
    t_feats, t_targs = list(zip(*test_dataset))
    use_rmse = 'rmse' in update.__name__.lower()

    for epoch in range(num_epochs):
        for feats, targs in train_loader:
            MLP_params = update(MLP_params, feats, targs)

        if epoch == 0 or (epoch + 1) % 10 == 0:
            _, test_loss = evaluate_test_set(predict, MLP_params, t_feats, t_targs, use_rmse)
            print('Epoch: {}'.format(epoch + 1), 'Test Loss: {}'.format(test_loss))

    # Score the final parameters so the returned predictions are never stale
    # (num_epochs not a multiple of 10) or undefined (num_epochs == 0).
    predictions_local, _ = evaluate_test_set(predict, MLP_params, t_feats, t_targs, use_rmse)
    return predictions_local, t_targs, train_dataset.scaler