In [39]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score
import torch
import jax.numpy as jnp
from torch.utils.data import Dataset, DataLoader, default_collate
from jax.tree_util import tree_map
from flax import nnx
import optax
from tqdm import tqdm

In [8]:
df = pd.read_csv("titanic-train.csv")
print("shape: ", df.shape)

shape:  (891, 12)


In [6]:
df.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [9]:
def preprocess_titanic_data(df):
    """ Basic preprocessing of Titanic dataset to train ML model

    Args:
        df (pd.DataFrame): Titanic dataset from Kaggle
    """

    df = df.drop(columns=["Name", "Ticket", "Cabin"], axis=1)
    df["Sex"] = df["Sex"].map({'male': 0, 'female': 1})
    df["Embarked"] = df["Embarked"].map({'S': 0, 'C': 1, 'Q': 2})
    df["Age"] = df["Age"].fillna(df["Age"].mean())
    df = df.dropna()
    labels = df.pop("Survived")

    return df, labels

In [10]:
df, labels = preprocess_titanic_data(df)
df.head()

Unnamed: 0,PassengerId,Pclass,Sex,Age,SibSp,Parch,Fare,Embarked
0,1,3,0,22.0,1,0,7.25,0.0
1,2,1,1,38.0,1,0,71.2833,1.0
2,3,3,1,26.0,0,0,7.925,0.0
3,4,1,1,35.0,1,0,53.1,0.0
4,5,3,0,35.0,0,0,8.05,0.0


In [12]:
X_train, X_test, y_train, y_test = train_test_split(df, labels, test_size=0.2)

### baseline model (random forest)

In [14]:
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)

In [19]:
print("accuracy_score = ", round(accuracy_score(y_test, y_pred),4))
print("precision_score = ", round(precision_score(y_test, y_pred),4))
print("recall_score = ", round(recall_score(y_test, y_pred),4))

accuracy_score =  0.8371
precision_score =  0.8276
recall_score =  0.7164


### jax

In [59]:
class TitanicDataset(Dataset):
    def __init__(self, X, y):
        self.features = X.reset_index(drop=True)
        self.labels = y.reset_index(drop=True)

    def __getitem__(self, index):
        x = torch.tensor(self.features.iloc[index].values, dtype=torch.float32)
        y = torch.tensor(self.labels.iloc[index], dtype=torch.float32)

        return x, y
    
    def __len__(self):
        return self.labels.shape[0]

In [60]:
def numpy_collate(batch):
    return tree_map(jnp.asarray, default_collate(batch))

In [61]:
train_ds = TitanicDataset(X_train, y_train)
test_ds = TitanicDataset(X_test, y_test)

In [62]:
train_dataloader_jax = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=numpy_collate)
test_dataloader_jax = DataLoader(test_ds, batch_size=64, shuffle=False, collate_fn=numpy_collate)

In [63]:
class TitanicNNX(nnx.Module):
    def __init__(self, num_hidden_1, num_hidden_2, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(8, num_hidden_1, rngs=rngs)
        self.dropout = nnx.Dropout(0.01, rngs=rngs)
        self.relu = nnx.leaky_relu
        self.linear2 = nnx.Linear(num_hidden_1, num_hidden_2, rngs=rngs)
        self.linear3 = nnx.Linear(num_hidden_2, 1, use_bias=False, rngs=rngs)

    def __call__(self, x):
        x = self.linear1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        x = self.relu(x)
        out = self.linear3(x)
        return out

In [64]:
model = TitanicNNX(32, 16, rngs=nnx.Rngs(0))
nnx.display(model)

In [65]:
def train(model, train_dataloader, test_dataloader, num_epochs):
    optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate=0.01))

    for epoch in (pbar := tqdm(range(num_epochs))):
        pbar.set_description(f"Epoch {epoch}")
        model.train()
        for batch in train_dataloader:
            train_step(model, optimizer, batch)
        
        pbar.set_postfix(train_accuracy=eval(model, train_dataloader), eval_accuracy=eval(model, test_dataloader))

@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model):
        logits = model(batch[0])
        loss = optax.sigmoid_binary_cross_entropy(logits.squeeze(), batch[1]).mean()

        return loss
    grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = grad_fn(model)
    optimizer.update(grads)

def eval(model, test_dataloader):
    total = 0
    num_correct = 0
    model.eval()
    for batch in test_dataloader:
        res = eval_step(model, batch)
        total += res.shape[0]
        num_correct += jnp.sum(res)

    return num_correct / total

@nnx.jit
def eval_step(model, batch):
    logits = model(batch[0])
    logits = logits.squeeze()
    preds = jnp.round(nnx.sigmoid(logits))

    return preds == batch[1]

In [66]:
train(model, train_dataloader_jax, test_dataloader_jax, num_epochs=500)

Epoch 499: 100%|██████████| 500/500 [00:21<00:00, 23.44it/s, eval_accuracy=0.7696629, train_accuracy=0.82137835] 
