In [1]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
import optax
from sklearn.metrics import accuracy_score

Loading Dataset

In [2]:
from ast import literal_eval
from collections import Counter
from datetime import datetime
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, classification_report
from sherlock.deploy.model import SherlockModel

start = datetime.now()

X_train = pd.read_parquet('../data/data/processed/train.parquet')
y_train = pd.read_parquet('../data/data/raw/train_labels.parquet').values.flatten()
y_train = np.array([x.lower() for x in y_train])

X_validation = pd.read_parquet('../data/data/processed/validation.parquet')
y_validation = pd.read_parquet('../data/data/raw/val_labels.parquet').values.flatten()
y_validation = np.array([x.lower() for x in y_validation])

X_test = pd.read_parquet('../data/data/processed/test.parquet')
y_test = pd.read_parquet('../data/data/raw/test_labels.parquet').values.flatten()
y_test = np.array([x.lower() for x in y_test])

print(f'Load data process took {datetime.now() - start} seconds.')

2022-06-27 17:02:56.690610: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


Load data process took 0:00:17.475816 seconds.


In [3]:
from sklearn.preprocessing import LabelEncoder
import tensorflow as tf

from sherlock.deploy import helpers

num_classes = len(set(y_train))

encoder = LabelEncoder()
encoder.fit(y_train)

feature_cols = helpers.categorize_features()

X_train_char = X_train[feature_cols["char"]]
X_train_word = X_train[feature_cols["word"]]
X_train_par = X_train[feature_cols["par"]]
X_train_rest = X_train[feature_cols["rest"]]

X_val_char = X_validation[feature_cols["char"]]
X_val_word = X_validation[feature_cols["word"]]
X_val_par = X_validation[feature_cols["par"]]
X_val_rest = X_validation[feature_cols["rest"]]

y_train_int = encoder.transform(y_train)   #(412059,)
y_train_cat = tf.keras.utils.to_categorical(y_train_int) #(412059,78)
y_val_int = encoder.transform(y_validation)
y_val_cat = tf.keras.utils.to_categorical(y_val_int)



  feature_cols_dict[feature_set] = pd.read_csv(


  feature_cols_dict[feature_set] = pd.read_csv(


  feature_cols_dict[feature_set] = pd.read_csv(


  feature_cols_dict[feature_set] = pd.read_csv(


In [4]:
len(feature_cols["word"])

201

960 single char level model training

Dense_0: {
    bias: (300,),
    kernel: (960, 300),
},
Dense_1: {
    bias: (300,),
    kernel: (300, 300),
},
Dense_2: {
    bias: (78,),
    kernel: (300, 78),
    
1 layer + 2*300 layers + 78 output layer


In [5]:
import graphviz

class SubModel(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    
    return x
    
class MainModel(nn.Module):
    feature_size: int = 500
    
    @nn.compact
    def __call__(self, concated_X):
        #x = jnp.concatenate((x1, x2, x3, x4), axis=-1)
        return nn.Dense(self.feature_size)(concated_X)

    
'''model = SubModel([300, 300, 78]) 
rng = jax.random.PRNGKey(0)
rng1, rng2 = jax.random.split(rng)
params = model.init(rng2, jax.random.normal(rng1, (960,)))'''


char_model_input = len(feature_cols["char"])
word_model_input = len(feature_cols["word"])
par_model_input = len(feature_cols["par"])
rest_model_input = len(feature_cols["rest"])

# Submodels to concat

# layers: (960,300) (300,300)
char_model = SubModel([300, 300])
word_model = SubModel([200, 200])
par_model = SubModel([400, 400])
rest_model = SubModel([27])
mainmodel = MainModel()

# Dummy data to initialize
b_char = jnp.ones((1, 960))
b_word = jnp.ones((1, 201))
b_par  = jnp.ones((1, 400))
b_rest = jnp.ones((1, 27))

# initialize dense layer
p_char = char_model.init(jax.random.PRNGKey(0), b_char)
p_word = word_model.init(jax.random.PRNGKey(0), b_word)
p_par  = par_model.init(jax.random.PRNGKey(0), b_par)
p_rest = rest_model.init(jax.random.PRNGKey(0), b_rest)

j_r = jnp.concatenate((jnp.ones((1, 300)), jnp.ones((1, 200)), jnp.ones((1, 400)), jnp.ones((1, 27))), axis=-1)

p_main = mainmodel.init(jax.random.PRNGKey(0),j_r) 


def main(x1, x2, x3, x4, p1, p2, p3, p4, pmain):
    y1 = char_model.apply(p1, x1)
    y2 = word_model.apply(p2, x2)
    y3 = par_model.apply(p3, x3)
    y4 = rest_model.apply(p4, x4)
    
    j_r = jnp.concatenate((y1,y2,y3,y4), axis=-1)
    
    return mainmodel.apply(pmain, j_r)

lowered = jax.jit(main).lower(b_char, b_word, b_par, b_rest, p_char, p_word, p_par, p_rest, p_main)

comp_dot = graphviz.Source(lowered._xla_computation().as_hlo_dot_graph())



In [6]:
comp_dot.render('nn_outcome', view=True).replace('\\', '/') 

'nn_outcome.pdf'

OPTAX

In [7]:
learning_rate = 0.0001
n_training_steps = 100

# Define an MSE loss function.
def make_mse_func(x_batched, y_batched): 
  print(len(x_batched[0]))
  def mse(p_main):    
    # Define the squared loss for a single (x, y) pair.
    def squared_error(x, y):      
      pred = mainmodel.apply(p_main, x)
      return jnp.inner(y-pred, y-pred) / 2.0  
    
    # Vectorise the squared error and compute the average of the loss.
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
  return jax.jit(mse)  # `jit` the result.

Using training data:  df -> jnp

http://bl.ocks.org/miguelusque/raw/f44a8e729896a96d0a3e4b07b5176af4/#numpy-jax

In [8]:
params = p_main

j_1 = jnp.array(pd.DataFrame(X_train_char).to_numpy())
j_2 = jnp.array(pd.DataFrame(X_train_word).to_numpy())
j_3 = jnp.array(pd.DataFrame(X_train_par).to_numpy())
j_4 = jnp.array(pd.DataFrame(X_train_rest).to_numpy())

dst_x = jnp.concatenate((j_1, j_2, j_3, j_4), axis=-1)
j_r = jnp.array()


dst_y = jnp.array(y_train_cat)

# Instantiate the sampled loss.

loss = make_mse_func(dst_x, dst_y)

optimizer = optax.adam(learning_rate=learning_rate)

# Create optimiser state.
opt_state = optimizer.init(params)
# Compute the gradient of the loss function.
loss_grad_fn = jax.value_and_grad(loss)


Error: no "view" mailcap rules found for type "application/pdf"


1588


In [9]:
dst_x[0]
len(dst_x[0]) #1588
len(dst_x) #412059

SyntaxError: invalid syntax (653407623.py, line 2)

Execute training

In [10]:
# Minimise the loss.
start = datetime.now()

for step in range(10):
    # Compute gradient of the loss.
    loss_val, grads = loss_grad_fn(params)
    # Update the optimiser state, create an update to the params.
    updates, opt_state = optimizer.update(grads, opt_state)
    # Update the parameters.
    params = optax.apply_updates(params, updates)
    if step%100 == 0: 
        print(f'Loss[{step}] = {loss_val}')
        
print(f'Load data process took {datetime.now() - start} seconds.')

ScopeParamShapeError: Inconsistent shapes between value and initializer for parameter "kernel" in "/Dense_0": (927, 500), (1588, 500). (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeParamShapeError)

In [None]:
#check param
print(params["params"]["Dense_0"]["kernel"])
print(params["params"]["Dense_0"]["kernel"].shape)

In [None]:
## Training set
model = mainmodel

## predict
print(dst_x.shape) #(412059, 960)
y_pred = model.apply(params, dst_x)
print(y.shape) #(412059, 78)

## score
y_pred_classes = helpers._proba_to_classes(y_pred, "sherlock")

print(y_pred_classes)
print(y_pred_classes.shape)

print(f1_score(y_train, y_pred_classes, average="weighted"))
print(accuracy_score(y_train, y_pred_classes))

In [None]:
print(X_train_char.shape) #(412059, 960)

In [None]:
## Validation set

## predict
print(X_val_char.shape) #(412059, 960)
y_val = model.apply(params, X_val_char)
print(y_val.shape) #(412059, 78)

## score
y_val_pred_classes = helpers._proba_to_classes(y_val, "sherlock")

print(y_val_pred_classes)

print(f1_score(y_validation, y_val_pred_classes, average="weighted"))
print(accuracy_score(y_validation, y_val_pred_classes))

Other feature models