In [1]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
import optax

from ast import literal_eval
from collections import Counter
from datetime import datetime
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, classification_report, accuracy_score

from sherlock.deploy.model import SherlockModel
from sherlock.deploy import helpers
import graphviz

from sklearn.preprocessing import LabelEncoder
import tensorflow as tf

## Download & Prep Data

In [2]:
# from sherlock.features.paragraph_vectors import initialise_pretrained_model, initialise_nltk
# from sherlock.features.preprocessing import (
#     extract_features,
#     convert_string_lists_to_lists,
#     prepare_feature_extraction,
#     load_parquet_values,
# )

# from sherlock.features.word_embeddings import initialise_word_embeddings


# prepare_feature_extraction()
# initialise_word_embeddings()
# initialise_pretrained_model(400)
# initialise_nltk()


## Loading Dataset

In [3]:
start = datetime.now()

#X_train = pd.read_parquet('../data/data/processed/train.parquet')

X_train = pd.read_parquet('../data/data/processed/train.parquet')
y_train = pd.read_parquet('../data/data/raw/train_labels.parquet').values.flatten()
y_train = np.array([x.lower() for x in y_train])

X_validation = pd.read_parquet('../data/data/processed/validation.parquet')
y_validation = pd.read_parquet('../data/data/raw/val_labels.parquet').values.flatten()
y_validation = np.array([x.lower() for x in y_validation])

X_test = pd.read_parquet('../data/data/processed/test.parquet')
y_test = pd.read_parquet('../data/data/raw/test_labels.parquet').values.flatten()
y_test = np.array([x.lower() for x in y_test])

print(f'Load data process took {datetime.now() - start} seconds.')

Load data process took 0:00:17.111887 seconds.


In [4]:
num_classes = len(set(y_train))

encoder = LabelEncoder()
encoder.fit(y_train)

feature_cols = helpers.categorize_features()

X_train_char = X_train[feature_cols["char"]]
X_train_word = X_train[feature_cols["word"]]
X_train_par = X_train[feature_cols["par"]]
X_train_rest = X_train[feature_cols["rest"]]

X_val_char = X_validation[feature_cols["char"]]
X_val_word = X_validation[feature_cols["word"]]
X_val_par = X_validation[feature_cols["par"]]
X_val_rest = X_validation[feature_cols["rest"]]

y_train_int = encoder.transform(y_train)   #(412059,)
y_train_cat = tf.keras.utils.to_categorical(y_train_int) #(412059,78)

y_val_int = encoder.transform(y_validation)
y_val_cat = tf.keras.utils.to_categorical(y_val_int)



  feature_cols_dict[feature_set] = pd.read_csv(


  feature_cols_dict[feature_set] = pd.read_csv(


  feature_cols_dict[feature_set] = pd.read_csv(


  feature_cols_dict[feature_set] = pd.read_csv(


In [5]:
j_1 = jnp.array(pd.DataFrame(X_train_char).to_numpy())
j_2 = jnp.array(pd.DataFrame(X_train_word).to_numpy())
j_3 = jnp.array(pd.DataFrame(X_train_par).to_numpy())
j_4 = jnp.array(pd.DataFrame(X_train_rest).to_numpy())

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Models

In [6]:
class RestModel(nn.Module):
    features: Sequence[int]
  
    @nn.compact
    def __call__(self, x):
        
        x = nn.BatchNorm(use_running_average=True,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=jnp.float32)(x)
        
        return x

In [7]:
class SubModel(nn.Module):
    features: Sequence[int]
    training: bool = True
    
    @nn.compact
    def __call__(self, x):
        
        # batchnormalisation - https://github.com/google/flax/issues/932
        x = nn.BatchNorm(use_running_average=True,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=jnp.float32)(x)
        
        x = nn.relu(nn.Dense(self.features[0])(x))
        
        # dropout
        x = nn.Dropout(rate=0.35)(x, deterministic=True)
                
        x = nn.relu(nn.Dense(self.features[1])(x)) 
        
        # todo: add 
        
        
        return x

In [8]:
# PRNGKey Example
print(random.PRNGKey(0))
key1, key2 = random.split(random.PRNGKey(0))
print(key1)
print(random.normal(key1,shape=(1,)))

a = random.normal(key1, (10,))
print(a)

[0 0]
[4146024105  967050713]
[0.14389051]
[-2.6105583   0.03385283  1.0863333  -1.4802988   0.48895672  1.062516
  0.54174834  0.0170228   0.2722685   0.30522448]


In [9]:
class MainModel(nn.Module):
    feature_size: int = 500
    num_classes: int = 78
        
    @nn.compact
    def __call__(self, x1, x2, x3, x4):
       
        # [1] define shape        
        y1 = SubModel([300, 300], name='char_model')(x1)      
        y2 = SubModel([200, 200], name='word_model')(x2)
        y3 = SubModel([400, 400], name='par_model')(x3)
        y4 = RestModel([27], name='rest_model')(x4)
                      
        # [2] concat submodels    
        x = jnp.concatenate((y1, y2, y3, y4), axis=-1)
        
        print("check mainmodel shape")
        print(np.shape(x))
        
        # batchnormalisation
        x = nn.BatchNorm(use_running_average=True,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=jnp.float32)(x)
        
        # dense 1
        x = nn.relu(nn.Dense(self.feature_size)(x))
        
        # dropout
        x = nn.Dropout(rate=0.35)(x, deterministic=True)
        
        # dense 2
        x = nn.relu(nn.Dense(self.feature_size)(x))
        
        # dense w/ softmax - todo: check
        x = nn.softmax(nn.Dense(self.feature_size)(x), axis=-1)
        
        return nn.Dense(self.num_classes)(x)


In [10]:
mainmodel = MainModel()
p_main = mainmodel.init(jax.random.PRNGKey(0), jnp.ones((1, 960)), jnp.ones((1, 201)), jnp.ones((1, 400)), jnp.ones((1, 27))) 
#(1, 927)



# ??
# it's not actual training part ( training = optax)
# should I use jnp.ones((1, 960)) instead of actual data for faster computation


#y_main = mainmodel.apply(p_main,j_1,j_2,j_3,j_4)
#(412059, 927)


check mainmodel shape
(1, 927)


In [11]:
#p_main
p_main['params'].keys()

frozen_dict_keys(['char_model', 'word_model', 'par_model', 'rest_model', 'BatchNorm_0', 'Dense_0', 'Dense_1', 'Dense_2', 'Dense_3'])

## Model Graph

In [12]:
# lowered = jax.jit(mainmodel.apply).lower(p_main,j_1,j_2,j_3,j_4)
# comp_dot = graphviz.Source(lowered._xla_computation().as_hlo_dot_graph())
# comp_dot.render('nn_outcome', view=True).replace('\\', '/') 

In [13]:
#print(y_main.shape)

## Training - OPTAX

In [14]:
learning_rate = 0.0001
n_training_steps = 100

# Define an MSE loss function.
def make_mse_func(x_b_1, x_b_2, x_b_3, x_b_4, y_batched):
  def mse(p_main):    
    # Define the squared loss for a single (x, y) pair.
    def squared_error(x1, x2, x3, x4, y):      
      pred = mainmodel.apply(p_main, x1, x2, x3, x4)
      return jnp.inner(y-pred, y-pred) / 2.0  
    
    # Vectorise the squared error and compute the average of the loss.
    return jnp.mean(jax.vmap(squared_error)(x_b_1, x_b_2, x_b_3, x_b_4, y_batched), axis=0)
  return jax.jit(mse)  # `jit` the result.

In [2]:
params = p_main

#dst_x = jnp.concatenate((j_1, j_2, j_3, j_4), axis=-1)
dst_y = jnp.array(y_train_cat)

# Instantiate the sampled loss.
loss = make_mse_func(j_1, j_2, j_3, j_4, dst_y)

optimizer = optax.adam(learning_rate=learning_rate)

# Create optimiser state.
opt_state = optimizer.init(params)

# Compute the gradient of the loss function.
loss_grad_fn = jax.value_and_grad(loss)


NameError: name 'p_main' is not defined

In [3]:
y_train_cat

NameError: name 'y_train_cat' is not defined

## Execute training

In [16]:
# Minimise the loss.
start = datetime.now()

for step in range(500):
    # Compute gradient of the loss.
    loss_val, grads = loss_grad_fn(params)
    # Update the optimiser state, create an update to the params.
    updates, opt_state = optimizer.update(grads, opt_state)
    # Update the parameters.
    params = optax.apply_updates(params, updates)
     
    print(f'Loss[{step}] = {loss_val}')
        
print(f'process took {datetime.now() - start} seconds.')

check mainmodel shape
(927,)


: 

: 

In [None]:
#check param
print(params["params"]["Dense_0"]["kernel"])
print(params["params"]["Dense_0"]["kernel"].shape)

[[ 0.02866286  0.00391353 -0.01072783 ... -0.03264792  0.00024129
  -0.00021848]
 [ 0.00095735 -0.0189361  -0.02946479 ...  0.04912312 -0.00165965
   0.01241112]
 [ 0.03084684  0.04030455  0.01075094 ... -0.02086708 -0.00404731
  -0.06132798]
 ...
 [-0.06701497  0.04543946  0.03028561 ... -0.01327508  0.04897631
   0.03002343]
 [ 0.01275493  0.01944453 -0.04737942 ...  0.02383648 -0.04277041
   0.06001806]
 [ 0.02494416 -0.00382685  0.0251975  ...  0.05030839 -0.0230723
   0.03849551]]
(927, 500)


## Prediction

In [None]:
j_v_1 = jnp.array(pd.DataFrame(X_val_char).to_numpy())
j_v_2 = jnp.array(pd.DataFrame(X_val_word).to_numpy())
j_v_3 = jnp.array(pd.DataFrame(X_val_par).to_numpy())
j_v_4 = jnp.array(pd.DataFrame(X_val_rest).to_numpy())

In [None]:
# print(dst_x.shape) #(412059, 960)
y_pred = mainmodel.apply(params, j_v_1, j_v_2, j_v_3, j_v_4)
print(y_pred.shape) #(137353, 78)

(137353, 78)


## Score

In [None]:
y_pred_classes = helpers._proba_to_classes(y_pred, "sherlock")

print(y_pred_classes)
print(y_pred_classes.shape)

print(f1_score(y_validation, y_pred_classes, average="weighted"))
print(accuracy_score(y_validation, y_pred_classes))

['county' 'genre' 'age' ... 'duration' 'class' 'jockey']
(137353,)
0.6771786681057324
0.7096896318245688


## memo

In [None]:
# TEST Submodel | Dropout separate key issue: Don't need to worry about. The aim is passing the same key to drop out same nodes across multiple devices.

# Can't use one line here like how it's calling submodel in mainmodel?
# https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.CallCompactUnboundModuleError
# submodel = SubModel([300, 300], name='char_model')(j_1)
submodel = SubModel([300, 300], name='char_model')


#p_sub = submodel.init(jax.random.PRNGKey(0),jnp.ones((1, 960)))
p_sub = submodel.init({'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(0)},jnp.ones((1, 960)))
y = submodel.apply(p_sub, j_1, rngs={'dropout': jax.random.PRNGKey(0)})

p_sub['params'].keys()

# Do we need to use separate keys like - key1, key2, key3, key4 = random.split(random.PRNGKey(seed), 4) ?
# https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_4_Flax_Zero2Hero_Colab.ipynb
# https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Dropout.html