<a href="https://colab.research.google.com/github/guiOsorio/Learning_JAX/blob/master/TitanicJAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Build a JAX model for the Titanic dataset

In [1]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap, grad, value_and_grad
from jax import random
import jax
from jax.scipy.special import logsumexp

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import time
import matplotlib.pyplot as plt
import pandas as pd

#### LOAD THE DATA

In [2]:
# Mount Google Drive so the CSV files under MyDrive/ can be read by the cells below.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive/'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [3]:
# Load training set: the labelled CSV (it contains the 'Survived' target column).
trainval_path = '/content/drive/MyDrive/titanic_train.csv'
trainval_df = pd.read_csv(filepath_or_buffer=trainval_path)
trainval_df.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [4]:
# Load testing set: same columns as the training CSV minus the 'Survived' target.
test_path = '/content/drive/MyDrive/titanic_test.csv'
test_df = pd.read_csv(filepath_or_buffer=test_path)
test_df.head()

Unnamed: 0,PassengerId,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,892,3,"Kelly, Mr. James",male,34.5,0,0,330911,7.8292,,Q
1,893,3,"Wilkes, Mrs. James (Ellen Needs)",female,47.0,1,0,363272,7.0,,S
2,894,2,"Myles, Mr. Thomas Francis",male,62.0,0,0,240276,9.6875,,Q
3,895,3,"Wirz, Mr. Albert",male,27.0,0,0,315154,8.6625,,S
4,896,3,"Hirvonen, Mrs. Alexander (Helga E Lindqvist)",female,22.0,1,1,3101298,12.2875,,S


#### EXPLORE THE TRAINING SET AND CLEAN/PREPROCESS

In [5]:
# Summary statistics of the numeric columns; the 'count' row already exposes missing values (e.g. Age).
trainval_df.describe()

Unnamed: 0,PassengerId,Survived,Pclass,Age,SibSp,Parch,Fare
count,891.0,891.0,891.0,714.0,891.0,891.0,891.0
mean,446.0,0.383838,2.308642,29.699118,0.523008,0.381594,32.204208
std,257.353842,0.486592,0.836071,14.526497,1.102743,0.806057,49.693429
min,1.0,0.0,1.0,0.42,0.0,0.0,0.0
25%,223.5,0.0,2.0,20.125,0.0,0.0,7.9104
50%,446.0,0.0,3.0,28.0,0.0,0.0,14.4542
75%,668.5,1.0,3.0,38.0,1.0,0.0,31.0
max,891.0,1.0,3.0,80.0,8.0,6.0,512.3292


In [6]:
# Non-missing values per column (anything below the row count has NaNs to handle)
trainval_df.notna().sum()

PassengerId    891
Survived       891
Pclass         891
Name           891
Sex            891
Age            714
SibSp          891
Parch          891
Ticket         891
Fare           891
Cabin          204
Embarked       889
dtype: int64

In [7]:
# Non-missing values per column of the test set
test_df.notna().sum()

PassengerId    418
Pclass         418
Name           418
Sex            418
Age            332
SibSp          418
Parch          418
Ticket         418
Fare           417
Cabin           91
Embarked       418
dtype: int64

In [8]:
# Drop columns that are not used as features:
#  - 'Cabin': mostly missing (only 204/891 train and 91/418 test values present)
#  - 'Name' and 'Ticket': free-text identifiers
# 'PassengerId' is dropped from the training set only. test_df keeps it, presumably so predictions
# can be matched back to passengers (TODO confirm), which means it must be excluded from the test
# feature matrix before it is fed to the model.
trainval_df = trainval_df.drop(columns=['PassengerId', 'Name', 'Ticket', 'Cabin'])
test_df = test_df.drop(columns=['Name', 'Ticket', 'Cabin'])

#### TRANSFORMS

In [9]:
def _add_fam_members(df):
  """Return a copy of df where 'SibSp' + 'Parch' are merged into one 'Fam_members' column."""
  return (
      df.assign(Fam_members=df['SibSp'] + df['Parch'])
        .drop(columns=['SibSp', 'Parch'])
  )

# Combine 'SibSp' and 'Parch' columns into one (family members) in both frames
trainval_df = _add_fam_members(trainval_df)
test_df = _add_fam_members(test_df)

trainval_df.head()

Unnamed: 0,Survived,Pclass,Sex,Age,Fare,Embarked,Fam_members
0,0,3,male,22.0,7.25,S,1
1,1,1,female,38.0,71.2833,C,1
2,1,3,female,26.0,7.925,S,0
3,1,1,female,35.0,53.1,S,1
4,0,3,male,35.0,8.05,S,0


In [10]:
# Convert 'Sex' variable to binary (1 if male, 0 if female).
# A direct comparison is used instead of pd.get_dummies(...)['male'] because:
#  - it cannot raise KeyError when a frame happens to contain no 'male' rows, and
#  - it always yields integers; get_dummies returns bool from pandas 2.0, and mixing bool with
#    float columns makes DataFrame.to_numpy() return an object array.
trainval_df['Sex'] = (trainval_df['Sex'] == 'male').astype(int)
test_df['Sex'] = (test_df['Sex'] == 'male').astype(int)

trainval_df.head()

Unnamed: 0,Survived,Pclass,Sex,Age,Fare,Embarked,Fam_members
0,0,3,1,22.0,7.25,S,1
1,1,1,0,38.0,71.2833,C,1
2,1,3,0,26.0,7.925,S,0
3,1,1,0,35.0,53.1,S,1
4,0,3,1,35.0,8.05,S,0


In [11]:
# Normalize 'Age' variable to [0, 1] with min-max scaling.
# NOTE: the fill value and the scaling range are computed over train+test combined.

# Series.append was removed in pandas 2.0; pd.concat is the supported equivalent.
age_appended = pd.concat([trainval_df['Age'], test_df['Age']])
min_age = age_appended.min()
max_age = age_appended.max()
mean_age = age_appended.mean()

## Fill NaN values with mean age (before scaling, so filled values are scaled like the rest)
trainval_df['Age'] = trainval_df['Age'].fillna(mean_age)
test_df['Age'] = test_df['Age'].fillna(mean_age)

## Normalize
age_range = max_age - min_age
trainval_df['Age'] = (trainval_df['Age'] - min_age) / age_range
test_df['Age'] = (test_df['Age'] - min_age) / age_range

trainval_df['Age']

0      0.273456
1      0.473882
2      0.323563
3      0.436302
4      0.436302
         ...   
886    0.336089
887    0.235876
888    0.372180
889    0.323563
890    0.398722
Name: Age, Length: 891, dtype: float64

In [12]:
# Normalize 'Fam_members' variable to [0, 1] with min-max scaling (range taken over train+test).

# Series.append was removed in pandas 2.0; pd.concat is the supported equivalent.
fam_appended = pd.concat([trainval_df['Fam_members'], test_df['Fam_members']])
min_fammembers = fam_appended.min()
max_fammembers = fam_appended.max()

## Normalize
fam_range = max_fammembers - min_fammembers
trainval_df['Fam_members'] = (trainval_df['Fam_members'] - min_fammembers) / fam_range
test_df['Fam_members'] = (test_df['Fam_members'] - min_fammembers) / fam_range

trainval_df['Fam_members']

0      0.1
1      0.1
2      0.0
3      0.1
4      0.0
      ... 
886    0.0
887    0.0
888    0.3
889    0.0
890    0.0
Name: Fam_members, Length: 891, dtype: float64

In [13]:
# Standardize 'Fare' variable (z-score: zero mean, unit standard deviation).
# Statistics are computed over train+test combined; mean/median/std skip NaNs.

# Series.append was removed in pandas 2.0; pd.concat is the supported equivalent.
fare_appended = pd.concat([trainval_df['Fare'], test_df['Fare']])
fare_mean = fare_appended.mean()
fare_median = fare_appended.median()
fare_std = fare_appended.std()

# Fill the (single) NaN value in the test set with the median fare
test_df['Fare'] = test_df['Fare'].fillna(fare_median)

# Standardize - df = (df - mean) / std
trainval_df['Fare'] = (trainval_df['Fare'] - fare_mean) / fare_std
test_df['Fare'] = (test_df['Fare'] - fare_mean) / fare_std

trainval_df['Fare']

0     -0.503210
1      0.733941
2     -0.490169
3      0.382632
4     -0.487754
         ...   
886   -0.392117
887   -0.063670
888   -0.190219
889   -0.063670
890   -0.493550
Name: Fare, Length: 891, dtype: float64

In [14]:
# Fill 'Embarked' NaNs with the most common port (mode over train+test combined).
# Series.append was removed in pandas 2.0; pd.concat is the supported equivalent.
embarked_appended = pd.concat([trainval_df['Embarked'], test_df['Embarked']])
embarked_mode = embarked_appended.mode()
# Fill both frames the same way. test_df has no missing ports in this data (418/418), so this is
# currently a no-op there, but it keeps the preprocessing of the two frames identical.
trainval_df['Embarked'] = trainval_df['Embarked'].fillna(embarked_mode.iloc[0])
test_df['Embarked'] = test_df['Embarked'].fillna(embarked_mode.iloc[0])

# One hot encode 'Embarked' variable.
# A shared categorical dtype guarantees both frames get the same Embarked_* columns even if a port
# never occurs in one of them. dtype=int keeps the indicators as 0/1 integers; get_dummies returns
# bool from pandas 2.0, and mixing bool with float columns makes DataFrame.to_numpy() return an
# object array.
embarked_dtype = pd.CategoricalDtype(sorted(embarked_appended.dropna().unique()))

trainval_embarked = pd.get_dummies(
    trainval_df['Embarked'].astype(embarked_dtype), prefix='Embarked', dtype=int)
trainval_df = pd.concat([trainval_df, trainval_embarked], axis=1)
trainval_df = trainval_df.drop(['Embarked'], axis=1)

test_embarked = pd.get_dummies(
    test_df['Embarked'].astype(embarked_dtype), prefix='Embarked', dtype=int)
test_df = pd.concat([test_df, test_embarked], axis=1)
test_df = test_df.drop(['Embarked'], axis=1)

trainval_df.head()

Unnamed: 0,Survived,Pclass,Sex,Age,Fare,Fam_members,Embarked_C,Embarked_Q,Embarked_S
0,0,3,1,0.273456,-0.50321,0.1,0,0,1
1,1,1,0,0.473882,0.733941,0.1,1,0,0
2,1,3,0,0.323563,-0.490169,0.0,0,0,1
3,1,1,0,0.436302,0.382632,0.1,0,0,1
4,0,3,1,0.436302,-0.487754,0.0,0,0,1


In [15]:
# One hot encode 'Pclass' variable (no NaN values).
# The categories are taken from train+test combined, so both frames always get the same Pclass_*
# columns. dtype=int keeps the indicators as 0/1 integers; get_dummies returns bool from pandas 2.0,
# and mixing bool with float columns makes DataFrame.to_numpy() return an object array.
pclass_dtype = pd.CategoricalDtype(
    sorted(pd.concat([trainval_df['Pclass'], test_df['Pclass']]).unique()))

trainval_pclass = pd.get_dummies(
    trainval_df['Pclass'].astype(pclass_dtype), prefix='Pclass', dtype=int)
trainval_df = pd.concat([trainval_df, trainval_pclass], axis=1)
trainval_df = trainval_df.drop(['Pclass'], axis=1)

test_pclass = pd.get_dummies(
    test_df['Pclass'].astype(pclass_dtype), prefix='Pclass', dtype=int)
test_df = pd.concat([test_df, test_pclass], axis=1)
test_df = test_df.drop(['Pclass'], axis=1)

trainval_df.head()

Unnamed: 0,Survived,Sex,Age,Fare,Fam_members,Embarked_C,Embarked_Q,Embarked_S,Pclass_1,Pclass_2,Pclass_3
0,0,1,0.273456,-0.50321,0.1,0,0,1,0,0,1
1,1,0,0.473882,0.733941,0.1,1,0,0,1,0,0
2,1,0,0.323563,-0.490169,0.0,0,0,1,0,0,1
3,1,0,0.436302,0.382632,0.1,0,0,1,1,0,0
4,0,1,0.436302,-0.487754,0.0,0,0,1,0,0,1


In [16]:
# Check NaNs in both trainval and test sets (make sure all values are filled)
trainval_nans = trainval_df.isna().sum()
test_nans = test_df.isna().sum()
print(trainval_nans)
print(test_nans)

# Fail loudly instead of relying on a visual check of the printed counts.
assert trainval_nans.sum() == 0, 'trainval_df still has missing values'
assert test_nans.sum() == 0, 'test_df still has missing values'
# Apart from the target (train only) and 'PassengerId' (test only), both frames must expose the
# same feature columns, otherwise the model's input size will not match at prediction time.
assert (set(trainval_df.columns) - {'Survived'}) == (set(test_df.columns) - {'PassengerId'}), \
    'train and test feature columns differ'

Survived       0
Sex            0
Age            0
Fare           0
Fam_members    0
Embarked_C     0
Embarked_Q     0
Embarked_S     0
Pclass_1       0
Pclass_2       0
Pclass_3       0
dtype: int64
PassengerId    0
Sex            0
Age            0
Fare           0
Fam_members    0
Embarked_C     0
Embarked_Q     0
Embarked_S     0
Pclass_1       0
Pclass_2       0
Pclass_3       0
dtype: int64


In [17]:
# Split the labelled set into training and validation (80/20).
# NOTE: the split follows file order (no shuffling), so the last 20% of rows become validation.
split = round(len(trainval_df) * 0.8)

# Explicit positional slicing (equivalent to trainval_df[:split] but unambiguous)
train_df = trainval_df.iloc[:split]
val_df = trainval_df.iloc[split:]

print(len(train_df), len(val_df))

train_df.head()

713 178


Unnamed: 0,Survived,Sex,Age,Fare,Fam_members,Embarked_C,Embarked_Q,Embarked_S,Pclass_1,Pclass_2,Pclass_3
0,0,1,0.273456,-0.50321,0.1,0,0,1,0,0,1
1,1,0,0.473882,0.733941,0.1,1,0,0,1,0,0
2,1,0,0.323563,-0.490169,0.0,0,0,1,0,0,1
3,1,0,0.436302,0.382632,0.1,0,0,1,1,0,0
4,0,1,0.436302,-0.487754,0.0,0,0,1,0,0,1


#### NEURAL NET

In [25]:
# Initialize parameters
# Integer seed for the JAX PRNG key that initialises the network weights (fixed for reproducibility).
seed = 0

def init_params(layers_size, parent_key, scale=0.01):
  """Randomly initialise the weights and biases of a fully connected network.

  Args:
    layers_size: sequence of layer widths, input features first and output units last,
      e.g. [10, 8, 2] builds a 10 -> 8 -> 2 network (two weight layers).
    parent_key: JAX PRNG key (e.g. from jax.random.PRNGKey) from which every per-layer key
      is derived.
    scale: multiplier applied to the standard-normal draws, i.e. the standard deviation of
      both weights and biases. Defaults to 0.01, the value previously hard-coded.

  Returns:
    A list with one [weights, bias] pair per layer. weights has shape
    (out_of_layer, in_layer) and bias has shape (out_of_layer,).
  """
  params = []
  # JAX random keys are explicit and must not be reused: sampling twice from the same key gives
  # identical numbers. Splitting the parent key gives every layer its own independent key.
  keys = jax.random.split(parent_key, num=len(layers_size)-1)

  # Consecutive pairs of layer widths: (inputs to the layer, outputs of the layer)
  in_layers = layers_size[:-1]
  out_of_layers = layers_size[1:]

  for in_layer, out_of_layer, key in zip(in_layers, out_of_layers, keys):
    # Split once more so the weights and the bias are drawn from different keys.
    weights_key, bias_key = jax.random.split(key)

    # Initialize params to be [weights, bias]:
    # weights are n rows (neurons in this layer) x m columns (inputs to this layer / features);
    # bias holds one value per neuron.
    params.append([
        scale*jax.random.normal(weights_key, shape=(out_of_layer, in_layer)), # n x m matrix
        scale*jax.random.normal(bias_key, shape=(out_of_layer,)) # vector with n values
    ])

  return params
key = jax.random.PRNGKey(seed)
# 10 inputs = the feature columns of train_df once 'Survived' is removed; 2 outputs, presumably one
# score per class (died / survived) - TODO confirm against the loss function.
params = init_params([10,8,2], key)

# Go through each layer of the initialized params and check if the shape is the expected one.
# jax.tree_map is deprecated (and removed in recent JAX); jax.tree_util.tree_map is the stable API.
jax.tree_util.tree_map(lambda x: x.shape, params)

[[(8, 10), (8,)], [(2, 8), (2,)]]

In [19]:
# Dataset class
class CustomDataset(Dataset):
  """Map-style dataset that pairs each feature row with its label.

  Item ``idx`` is the dict ``{'sample': data[idx, :], 'target': targets[idx]}`` with both
  values converted to new torch tensors. Because rows are read as ``data[idx, :]``, ``data``
  presumably is a 2-D NumPy array (a DataFrame would need ``.to_numpy()`` first) - TODO confirm
  against the caller.
  """

  def __init__(self, data, targets):
    self.data, self.targets = data, targets

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    row = self.data[idx, :]
    label = self.targets[idx]
    return {'sample': torch.tensor(row), 'target': torch.tensor(label)}

# test_dset = CustomDataset()

In [20]:
# Dataloaders with custom_collate + custom transforms


In [21]:
# Predict, loss, accuracy and update

In [22]:
# Train network with predictions on validation set

In [23]:
# Confusion matrix on validation set