<a href="https://www.kaggle.com/code/yno3fm36xqnnc8/solving-tabular-playground-jan-2021-with-flax?scriptVersionId=141492705" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import pandas as pd
import numpy as np

import jax
import jax.numpy as jnp
import flax.linen as nn
import flax

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from typing import Sequence, List, Any, Optional



The goal of this challenge is to predict the continuous `target` variable from 14 continuous features (`cont1` to `cont14`). Since this is a regression problem, we'll solve it with a simple multilayer perceptron (MLP) built with the `jax` and `flax` libraries. Growing out of work by the Google Brain team, `flax` is a higher-level neural network library built on top of `jax`.

First off, let's get the data loaded and check some common summary statistics for the features.

In [2]:
train_data = pd.read_csv("/kaggle/input/tabular-playground-series-jan-2021/train.csv").set_index('id')
train_data.describe()

|       | cont1 | cont2 | cont3 | cont4 | cont5 | cont6 | cont7 | cont8 | cont9 | cont10 | cont11 | cont12 | cont13 | cont14 | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 300000.0 | 300000.0 | 300000.0 | 300000.0 | 300000.0 | 300000.0 | 300000.0 | 300000.0 | 300000.0 | 300000.0 | 300000.0 | 300000.0 | 300000.0 | 300000.0 | 300000.0 |
| mean | 0.506873 | 0.497898 | 0.521557 | 0.515683 | 0.502022 | 0.526515 | 0.48789 | 0.525163 | 0.459857 | 0.520532 | 0.483926 | 0.506877 | 0.553442 | 0.503713 | 7.905661 |
| std | 0.203976 | 0.228159 | 0.20077 | 0.233035 | 0.220701 | 0.217909 | 0.181096 | 0.216221 | 0.196685 | 0.201854 | 0.220082 | 0.218947 | 0.22973 | 0.208238 | 0.733071 |
| min | -0.082263 | -0.031397 | 0.020967 | 0.152761 | 0.276377 | 0.066166 | -0.097666 | 0.21726 | -0.240604 | -0.085046 | 0.083277 | 0.088635 | 0.02995 | 0.166367 | 0.0 |
| 25% | 0.343078 | 0.31917 | 0.344096 | 0.294935 | 0.284108 | 0.356163 | 0.3466 | 0.341486 | 0.330832 | 0.375465 | 0.300474 | 0.310166 | 0.350472 | 0.308673 | 7.329367 |
| 50% | 0.484005 | 0.553209 | 0.551471 | 0.48288 | 0.451733 | 0.470988 | 0.466825 | 0.48346 | 0.416843 | 0.458877 | 0.441916 | 0.486599 | 0.487707 | 0.431845 | 7.940571 |
| 75% | 0.643789 | 0.731263 | 0.648315 | 0.748705 | 0.67066 | 0.694043 | 0.581292 | 0.68525 | 0.575041 | 0.700292 | 0.679128 | 0.694453 | 0.768479 | 0.712653 | 8.470084 |
| max | 1.016227 | 0.859697 | 1.006955 | 1.010402 | 1.034261 | 1.043858 | 1.066167 | 1.024427 | 1.004114 | 1.199951 | 1.02262 | 1.049025 | 0.977845 | 0.868506 | 10.267569 |
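The summary statistics already give two useful yardsticks for the loss. A model whose outputs start near zero has an expected squared error of roughly mean² + std² of the target, while always predicting the training mean gives an MSE equal to the target's variance. A quick sketch of that arithmetic, using only the `target` column of the table above (pandas reports the sample standard deviation, which is practically identical to the population value at 300,000 rows):

```python
# Target statistics copied from the describe() table above
target_mean = 7.905661
target_std = 0.733071

# Expected MSE of a model that predicts ~0 everywhere: E[y^2] = mean^2 + var
mse_zero = target_mean**2 + target_std**2

# MSE of the best constant predictor (the mean itself): just the variance
mse_mean = target_std**2

print(f"MSE predicting 0:    {mse_zero:.2f}")
print(f"MSE predicting mean: {mse_mean:.3f}")
```

The first logged training loss further down (about 62.6) is close to the first number, as one would expect if the freshly initialized network's outputs are small; a useful model has to beat the second.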


Now we see our first example of a `flax` model, in this case a simple MLP. Its only attribute is `features`, a sequence of integers giving the output size of each dense layer; every layer except the last is followed by a ReLU activation.

In [3]:
class MLP(nn.Module):
    features: Sequence[int]
            
    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x

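Now let's create a model instance, initialize its parameters from a dummy input of shape `(1, 14)`, and wrap `model.apply` with `jax.vmap` so that it maps over the rows of a batch.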

In [4]:
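model = MLP(features=[12, 8, 8, 8, 4, 4, 1])

# Initialize the parameters using a dummy batch with 14 input features
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 14)))
# Apply the model row by row: variables are shared (None), inputs are mapped over axis 0
vpredict = jax.vmap(model.apply, (None, 0))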

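We define the loss function, here the mean squared error (MSE), and use `jax.value_and_grad` to obtain a function that returns both the loss and its gradient with respect to the parameters.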

In [5]:
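def loss(variables: flax.core.frozen_dict.FrozenDict, X: jnp.ndarray, y: jnp.ndarray):
    # vpredict returns shape (N, 1); squeeze it to (N,) so it lines up with y
    # instead of broadcasting to an (N, N) matrix of pairwise differences
    preds = vpredict(variables, X).squeeze(-1)
    return jnp.mean(jnp.square(y - preds))

# Compile the loss-and-gradient computation once and reuse it at every step
loss_grad_fn = jax.jit(jax.value_and_grad(loss))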
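`jax.value_and_grad(loss)` returns a function computing both the loss and its gradient with respect to the first argument (here the whole parameter tree). As a sanity check of what it computes, the sketch below swaps in a hypothetical linear model `pred = X @ w`, for which the MSE gradient has the closed form −(2/N) Xᵀ(y − Xw), and compares that with JAX's result:

```python
import jax
import jax.numpy as jnp


def mse(w, X, y):
    # Hypothetical linear model standing in for the MLP
    return jnp.mean(jnp.square(y - X @ w))


key_x, key_w, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(key_x, (32, 14))
w = jax.random.normal(key_w, (14,))
y = jax.random.normal(key_y, (32,))

value, grad = jax.value_and_grad(mse)(w, X, y)

# Closed-form gradient of mean((y - Xw)^2) with respect to w
residual = y - X @ w
grad_analytic = -2.0 / X.shape[0] * X.T @ residual

print(float(value), jnp.allclose(grad, grad_analytic, rtol=1e-3, atol=1e-3))
```

The same mechanism applies to the MLP; JAX just differentiates through every `Dense` layer and ReLU instead of a single matrix product.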

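Below is the training loop, which runs plain minibatch gradient descent: each step samples 1000 rows, computes the loss and its gradients, and subtracts `learning_rate` times the gradient from every parameter. It stops early once the minibatch loss is below 1.0 and changes by less than 0.001 between steps. Here the update is written by hand, but in practice an optimizer library such as `optax` is usually more convenient.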

In [6]:
learning_rate = 1e-2
last_value = None
data = train_data.to_numpy()  # columns cont1..cont14, then target
rng = np.random.default_rng()
for i in range(1000):
    # Sample a minibatch of 1000 rows without replacement
    batch = rng.choice(data, 1000, axis=0, replace=False)
    X_train, y_train = batch[:, :-1], batch[:, -1]
    value, grads = loss_grad_fn(variables, X_train, y_train)

    # Stop once the minibatch loss is small and has stopped changing much
    if last_value is not None and abs(last_value - value) < 0.001 and value < 1.0:
        break

    last_value = value
    if i % 10 == 0:
        print(value)
    # Gradient descent step applied to every parameter array in the tree
    variables = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, variables, grads)

62.60427
23.943077


Finally, let's predict the target value for the test dataset and prepare the submission.

In [7]:
test_data = pd.read_csv("/kaggle/input/tabular-playground-series-jan-2021/test.csv").set_index('id')
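# Squeeze the (N, 1) model output to (N,) and convert it to NumPy before storing it as a column
test_data['target'] = np.asarray(vpredict(variables, jnp.array(test_data.to_numpy()))).squeeze(-1)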

test_data

| id | cont1 | cont2 | cont3 | cont4 | cont5 | cont6 | cont7 | cont8 | cont9 | cont10 | cont11 | cont12 | cont13 | cont14 | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.353600 | 0.738780 | 0.600939 | 0.293377 | 0.285691 | 0.458006 | 0.620704 | 0.422249 | 0.369203 | 0.435727 | 0.550540 | 0.699134 | 0.286864 | 0.364515 | 7.427158 |
| 2 | 0.907222 | 0.189756 | 0.215531 | 0.869915 | 0.301333 | 0.528958 | 0.390351 | 0.521112 | 0.794779 | 0.798580 | 0.446475 | 0.449037 | 0.916964 | 0.513002 | 7.240384 |
| 6 | 0.179287 | 0.355353 | 0.623972 | 0.437812 | 0.282476 | 0.320826 | 0.386789 | 0.776422 | 0.222268 | 0.229102 | 0.211913 | 0.222651 | 0.327164 | 0.827941 | 6.783124 |
| 7 | 0.359385 | 0.181049 | 0.551368 | 0.206386 | 0.280763 | 0.482076 | 0.506677 | 0.362793 | 0.379737 | 0.345686 | 0.445276 | 0.518485 | 0.299028 | 0.598166 | 6.783523 |
| 10 | 0.335791 | 0.682607 | 0.676481 | 0.219465 | 0.282861 | 0.581721 | 0.748639 | 0.350158 | 0.448915 | 0.506878 | 0.817721 | 0.805895 | 0.790591 | 0.249275 | 7.811783 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 499984 | 0.353856 | 0.677578 | 0.550852 | 0.869612 | 0.957635 | 0.255054 | 0.289138 | 0.635979 | 0.271399 | 0.282455 | 0.217169 | 0.219088 | 0.373261 | 0.272479 | 7.257768 |
| 499985 | 0.243209 | 0.135627 | 0.218393 | 0.792798 | 0.547639 | 0.433520 | 0.549540 | 0.650107 | 0.453787 | 0.459689 | 0.450424 | 0.511176 | 0.318334 | 0.395747 | 7.143597 |
| 499987 | 0.506973 | 0.683893 | 0.533434 | 0.192957 | 0.314381 | 0.358604 | 0.554455 | 0.267105 | 0.396101 | 0.445390 | 0.382656 | 0.397978 | 0.381235 | 0.369464 | 7.238574 |
| 499988 | 0.347870 | 0.553112 | 0.495284 | 0.861500 | 0.816914 | 0.298478 | 0.275964 | 0.265841 | 0.334250 | 0.252635 | 0.213589 | 0.285223 | 0.336772 | 0.388505 | 7.049764 |


In [8]:
test_data[['target']].to_csv("submission.csv")
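Because the test set was loaded with `set_index('id')`, `to_csv` writes the index as the first column, so the submission file ends up with an `id,target` header. A quick check on a toy frame (the ids and values below are made up for illustration):

```python
import io

import pandas as pd

# Made-up ids and predictions standing in for test_data
toy = pd.DataFrame({"id": [0, 2, 6], "target": [7.43, 7.24, 6.78]}).set_index("id")

# Write to an in-memory buffer instead of a file
buf = io.StringIO()
toy[["target"]].to_csv(buf)
lines = buf.getvalue().splitlines()
print(lines)
```

Passing `index=False` would drop the `id` column entirely, so the default behaviour is the one wanted here.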